
Commit 6d1501c

fix up network naming; add example to load network and run inference in notebook

misko committed Nov 3, 2024
1 parent f577855 commit 6d1501c
Showing 4 changed files with 195 additions and 43 deletions.
14 changes: 3 additions & 11 deletions spf/model_training_and_inference/models/beamsegnet.py
@@ -1,19 +1,11 @@
 import math
-from collections import OrderedDict
-from math import sqrt
 
 import numpy as np
 import torch
 import torch.nn as nn
-from torch.nn.functional import sigmoid
 
 from spf.dataset.spf_dataset import v5_thetas_to_targets
-from spf.rf import (
-    pi_norm,  # torch_circular_mean_weighted,
-    reduce_theta_to_positive_y,
-    torch_circular_mean,
-    torch_pi_norm,
-)
+from spf.rf import reduce_theta_to_positive_y, torch_circular_mean, torch_pi_norm
 
 
 class ConvNet(nn.Module):
@@ -315,10 +307,10 @@ def __init__(
             net_layout += [nn.Linear(hidden, outputs)]
         else:
             raise ValueError(f"Norm not implemented {norm}")
-        self.net = nn.Sequential(*net_layout)
+        self.ffnn_internal_net = nn.Sequential(*net_layout)
 
     def forward(self, x):
-        return self.net(x)
+        return self.ffnn_internal_net(x)
 
 
 class HalfPiEncoding(nn.Module):
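
Note: renaming self.net to self.ffnn_internal_net changes the keys in FFNN's state_dict, so checkpoints saved before this commit will not load into the renamed module directly. A minimal sketch (not part of this commit) of remapping the old keys, assuming the prefix rename is the only change and using hypothetical file paths:

import torch

# Remap pre-rename keys ("net.*") to the new module name ("ffnn_internal_net.*").
checkpoint = torch.load("old_checkpoint.pth", map_location="cpu")
checkpoint["model_state_dict"] = {
    (key.replace("net.", "ffnn_internal_net.", 1) if key.startswith("net.") else key): value
    for key, value in checkpoint["model_state_dict"].items()
}
torch.save(checkpoint, "remapped_checkpoint.pth")
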
12 changes: 8 additions & 4 deletions spf/model_training_and_inference/models/single_point_networks.py
@@ -75,7 +75,7 @@ class SinglePointWithBeamformer(nn.Module):
     def __init__(self, model_config, global_config):
         super().__init__()
         self.prepare_input = PrepareInput(model_config, global_config)
-        self.net = FFNN(
+        self.single_point_with_beamformer_ffnn = FFNN(
             inputs=self.prepare_input.inputs,
             depth=model_config["depth"],  # 4
             hidden=model_config["hidden"],  # 128
@@ -92,7 +92,11 @@ def forward(self, batch):
         # first dim odd / even is the radios
         return {
             "single": torch.nn.functional.normalize(
-                self.net(self.prepare_input.prepare_input(batch)).abs(), dim=2, p=1
+                self.single_point_with_beamformer_ffnn(
+                    self.prepare_input.prepare_input(batch)
+                ).abs(),
+                dim=2,
+                p=1,
             )
         }
 
@@ -104,7 +108,7 @@ def __init__(self, model_config, global_config):
             model_config["single"], global_config
         )
         self.detach = model_config.get("detach", True)
-        self.net = FFNN(
+        self.paired_single_point_with_beamformer_ffnn = FFNN(
             inputs=global_config["nthetas"] * 2,
             depth=model_config["depth"],  # 4
             hidden=model_config["hidden"],  # 128
@@ -123,7 +127,7 @@ def forward(self, batch):
             single_radio_estimates, self.detach
         )
 
-        x = self.net(
+        x = self.paired_single_point_with_beamformer_ffnn(
             torch.concatenate(
                 [single_radio_estimates_input[::2], single_radio_estimates_input[1::2]],
                 dim=2,
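
The forward pass takes the absolute value of the network output and L1-normalizes it along the theta dimension (dim=2, p=1), so each snapshot's output reads as a discrete distribution over the nthetas candidate angles. A standalone sketch of just that operation, with illustrative shapes:

import torch

logits = torch.randn(4, 1, 65)  # fake output: (batch, snapshots, nthetas)
probs = torch.nn.functional.normalize(logits.abs(), dim=2, p=1)
print(probs.sum(dim=2))  # every entry is 1.0: a per-snapshot distribution over theta
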
123 changes: 123 additions & 0 deletions spf/notebooks/load_model_checkpoint.ipynb
@@ -0,0 +1,123 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from spf.scripts.train_single_point import (\n",
" load_checkpoint,\n",
" load_config_from_fn,\n",
" load_model,\n",
")\n",
"\n",
"\n",
"def load_model_and_config_from_config_fn_and_checkpoint(config_fn, checkpoint_fn):\n",
" config = load_config_from_fn(config_fn)\n",
" config[\"optim\"][\"checkpoint\"] = checkpoint_fn\n",
" m = load_model(config[\"model\"], config[\"global\"]).to(config[\"optim\"][\"device\"])\n",
" m, _, _, _, _ = load_checkpoint(\n",
" checkpoint_fn=config[\"optim\"][\"checkpoint\"],\n",
" config=config,\n",
" model=m,\n",
" optimizer=None,\n",
" scheduler=None,\n",
" force_load=True,\n",
" )\n",
" return m, config\n",
"\n",
"\n",
"def convert_datasets_config_to_inference(datasets_config, ds_fn):\n",
" datasets_config = datasets_config.copy()\n",
" datasets_config.update(\n",
" {\n",
" \"batch_size\": 1,\n",
" \"flip\": False,\n",
" \"double_flip\": False,\n",
" \"precompute_cache\": \"/home/mouse9911/precompute_cache_chunk16_sept\",\n",
" \"shuffle\": False,\n",
" \"skip_qc\": True,\n",
" \"snapshots_adjacent_stride\": 1,\n",
" \"train_snapshots_per_session\": 1,\n",
" \"val_snapshots_per_session\": 1,\n",
" \"random_snapshot_size\": False,\n",
" \"snapshots_stride\": 1,\n",
" \"train_paths\": [ds_fn],\n",
" \"train_on_val\": True,\n",
" \"workers\": 1,\n",
" }\n",
" )\n",
" return datasets_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from spf.scripts.train_single_point import load_dataloaders\n",
"\n",
"from tqdm import tqdm\n",
"\n",
"config_fn = \"/home/mouse9911/gits/spf/nov2_checkpoints/nov2_small_paired_checkpoints_inputdo0p3/config.yml\"\n",
"checkpoint_fn = \"/home/mouse9911/gits/spf/nov2_checkpoints/nov2_small_paired_checkpoints_inputdo0p3/best.pth\"\n",
"ds_fn = \"/mnt/4tb_ssd/nosig_data/wallarrayv3_2024_08_21_10_30_58_nRX2_bounce_spacing0p05075.zarr\"\n",
"\n",
"# load model\n",
"model, config = load_model_and_config_from_config_fn_and_checkpoint(\n",
" config_fn=config_fn, checkpoint_fn=checkpoint_fn\n",
")\n",
"\n",
"# load datasets config\n",
"datasets_config = convert_datasets_config_to_inference(\n",
" config[\"datasets\"],\n",
" ds_fn=ds_fn,\n",
")\n",
"\n",
"# load dataloader\n",
"optim_config = {\"device\": \"cuda\", \"dtype\": torch.float32}\n",
"global_config = {\"nthetas\": 65, \"n_radios\": 2, \"seed\": 0, \"beamformer_input\": True}\n",
"train_dataloader, val_dataloader = load_dataloaders(\n",
" datasets_config, optim_config, config[\"global\"], step=0, epoch=0\n",
")\n",
"\n",
"# run inference\n",
"model.eval()\n",
"for _, val_batch_data in enumerate(tqdm(val_dataloader, leave=False)):\n",
" val_batch_data = val_batch_data.to(config[\"optim\"][\"device\"])\n",
" output = model(val_batch_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "spf",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
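
After the loop above, output holds the prediction for the last validation batch. A hedged sketch (not in the commit) of reducing a "single" head to a point estimate of the angle; it assumes the model exposes a "single" key shaped (batch, snapshots, nthetas) as in SinglePointWithBeamformer, and that the 65 theta bins uniformly span [-pi/2, pi/2], which this commit does not confirm:

import torch

nthetas = 65
theta_grid = torch.linspace(-torch.pi / 2, torch.pi / 2, nthetas)  # assumed bin centers
probs = output["single"]  # L1-normalized per-theta weights from the model
theta_hat = theta_grid[probs.argmax(dim=2).cpu()]  # most likely angle per snapshot
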
89 changes: 61 additions & 28 deletions spf/scripts/train_single_point.py
@@ -25,7 +25,7 @@
     SinglePointWithBeamformer,
     TrajPairedMultiPointWithBeamformer,
 )
-from spf.rf import torch_pi_norm, torch_reduce_theta_to_positive_y
+from spf.rf import torch_pi_norm
 from spf.utils import StatefulBatchsampler
 
 
@@ -64,7 +64,8 @@ def load_dataloaders(datasets_config, optim_config, global_config, step=0, epoch
     # glob.glob('./[0-9].*')
 
     train_dataset_filenames = expand_wildcards_and_join(datasets_config["train_paths"])
-    random.shuffle(train_dataset_filenames)
+    if datasets_config["shuffle"]:
+        random.shuffle(train_dataset_filenames)
 
     if datasets_config.get("train_on_val", False):
         val_paths = train_dataset_filenames
@@ -140,12 +141,15 @@ def load_dataloaders(datasets_config, optim_config, global_config, step=0, epoch
     train_idxs = range(len(train_ds))
     val_idxs = list(range(len(val_ds)))
 
-    random.shuffle(val_idxs)
-    val_idxs = val_idxs[
-        : max(
-            1, int(len(val_idxs) * datasets_config.get("val_subsample_fraction", 1.0))
-        )
-    ]
+    if datasets_config["shuffle"]:
+        random.shuffle(val_idxs)
+    if not datasets_config.get("train_on_val", False):
+        val_idxs = val_idxs[
+            : max(
+                1,
+                int(len(val_idxs) * datasets_config.get("val_subsample_fraction", 1.0)),
+            )
+        ]
 
     train_ds = torch.utils.data.Subset(train_ds, train_idxs)
     val_ds = torch.utils.data.Subset(val_ds, val_idxs)
@@ -272,7 +276,9 @@ def save_model(
         yaml.dump(running_config, outfile)
 
 
-def load_checkpoint(checkpoint_fn, config, model, optimizer, scheduler):
+def load_checkpoint(
+    checkpoint_fn, config, model, optimizer, scheduler, force_load=False
+):
     logging.info(f"Loading checkpoint {checkpoint_fn}")
     checkpoint = torch.load(checkpoint_fn, map_location=torch.device("cpu"))
 
@@ -293,26 +299,37 @@ def load_checkpoint(checkpoint_fn, config, model, optimizer, scheduler):
     # scheduler_being_loaded.load_state_dict(checkpoint["scheduler_state_dict"])
 
     # check if we loading a single network
-    if config["model"].get("load_single", False):
-        logging.info("Loading single_radio_net only")
-        model.single_radio_net.load_state_dict(checkpoint["model_state_dict"])
-        for param in model.single_radio_net.parameters():
-            param.requires_grad = False
-        # model.single_radio_net = FrozenModule(model_being_loaded)
-        return (model, optimizer, scheduler, 0, 0)  # epoch  # step
-    elif config["model"].get("load_paired", False):
-        # check if we loading a paired network
-        logging.info("Loading paired_radio net only")
-        model.multi_radio_net.load_state_dict(checkpoint["model_state_dict"])
-        for param in model.multi_radio_net.parameters():
-            param.requires_grad = False
-        # breakpoint()
-        return (model, optimizer, scheduler, 0, 0)  # epoch  # step
+    if not force_load:
+        if config["model"].get("load_single", False):
+            logging.info("Loading single_radio_net only")
+            model.single_radio_net.load_state_dict(checkpoint["model_state_dict"])
+            for param in model.single_radio_net.parameters():
+                param.requires_grad = False
+            # model.single_radio_net = FrozenModule(model_being_loaded)
+            return (model, optimizer, scheduler, 0, 0)  # epoch  # step
+        elif config["model"].get("load_paired", False):
+            # check if we loading a paired network
+            logging.info("Loading paired_radio net only")
+            model.multi_radio_net.load_state_dict(checkpoint["model_state_dict"])
+            for param in model.multi_radio_net.parameters():
+                param.requires_grad = False
+            # breakpoint()
+            return (model, optimizer, scheduler, 0, 0)  # epoch  # step
 
+    # else
+    logging.debug("loading_checkpoint: checkpoint state dict")
+    for key, v in checkpoint["model_state_dict"].items():
+        logging.debug(f"\t{key}\t{v.shape}")
+
+    logging.debug("loading_checkpoint: model state dict")
+    for key, v in model.state_dict().items():
+        logging.debug(f"\t{key}\t{v.shape}")
 
     model.load_state_dict(checkpoint["model_state_dict"])
-    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
-    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+    if optimizer is not None:
+        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+    if scheduler is not None:
+        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
 
     return (
         model,
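
Because optimizer and scheduler loading is now guarded with None checks, a checkpoint can be restored for inference without constructing either; this is exactly what the notebook helper above relies on:

m, _, _, _, _ = load_checkpoint(
    checkpoint_fn="best.pth",  # hypothetical path
    config=config,
    model=m,
    optimizer=None,  # skipped by the new None guard
    scheduler=None,
    force_load=True,  # bypass the load_single / load_paired branches
)
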
@@ -633,7 +650,7 @@ def train_single_point(args):
     if args.output is None:
         args.output = datetime.datetime.now().strftime("spf-run-%Y-%m-%d_%H-%M-%S")
     try:
-        os.mkdir(args.output)
+        os.makedirs(args.output)
     except FileExistsError:
         pass
 
@@ -677,7 +694,16 @@ def train_single_point(args):
     step = 0
     start_epoch = 0
 
-    if "checkpoint" in config["optim"]:
+    if args.resume_from is not None:
+        m, optimizer, scheduler, start_epoch, step = load_checkpoint(
+            checkpoint_fn=args.resume_from,
+            config=config,
+            model=m,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            force_load=True,
+        )
+    elif "checkpoint" in config["optim"]:
         m, optimizer, scheduler, start_epoch, step = load_checkpoint(
             checkpoint_fn=config["optim"]["checkpoint"],
             config=config,
@@ -893,6 +919,13 @@ def get_parser_filter():
         help="config file",
         required=True,
     )
+    parser.add_argument(
+        "--resume-from",
+        type=str,
+        help="resume from checkpoint file",
+        default=None,
+        required=False,
+    )
     parser.add_argument(
         "-o",
         "--output",
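
With the new flag, resuming a run from an explicit checkpoint might look like this (hypothetical paths; the config flag name and direct script invocation are assumptions):

python spf/scripts/train_single_point.py -c config.yml --resume-from spf-run-2024-11-03_00-00-00/best.pth

Unlike config["optim"]["checkpoint"], --resume-from calls load_checkpoint with force_load=True, so the full model, optimizer, and scheduler state are restored even when the config sets load_single or load_paired.
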
