From 6d1501c9bf6524b1229626c5ac666e44187a2c60 Mon Sep 17 00:00:00 2001 From: misko Date: Sun, 3 Nov 2024 00:54:08 +0000 Subject: [PATCH] fix up network naming; add example to load network and run inference in notebook --- .../models/beamsegnet.py | 14 +- .../models/single_point_networks.py | 12 +- spf/notebooks/load_model_checkpoint.ipynb | 123 ++++++++++++++++++ spf/scripts/train_single_point.py | 89 +++++++++---- 4 files changed, 195 insertions(+), 43 deletions(-) create mode 100644 spf/notebooks/load_model_checkpoint.ipynb diff --git a/spf/model_training_and_inference/models/beamsegnet.py b/spf/model_training_and_inference/models/beamsegnet.py index ceb659a..b0ce846 100644 --- a/spf/model_training_and_inference/models/beamsegnet.py +++ b/spf/model_training_and_inference/models/beamsegnet.py @@ -1,19 +1,11 @@ -import math from collections import OrderedDict from math import sqrt -import numpy as np import torch import torch.nn as nn -from torch.nn.functional import sigmoid from spf.dataset.spf_dataset import v5_thetas_to_targets -from spf.rf import ( - pi_norm, # torch_circular_mean_weighted, - reduce_theta_to_positive_y, - torch_circular_mean, - torch_pi_norm, -) +from spf.rf import reduce_theta_to_positive_y, torch_circular_mean, torch_pi_norm class ConvNet(nn.Module): @@ -315,10 +307,10 @@ def __init__( net_layout += [nn.Linear(hidden, outputs)] else: raise ValueError(f"Norm not implemented {norm}") - self.net = nn.Sequential(*net_layout) + self.ffnn_internal_net = nn.Sequential(*net_layout) def forward(self, x): - return self.net(x) + return self.ffnn_internal_net(x) class HalfPiEncoding(nn.Module): diff --git a/spf/model_training_and_inference/models/single_point_networks.py b/spf/model_training_and_inference/models/single_point_networks.py index 3db4269..4db4879 100644 --- a/spf/model_training_and_inference/models/single_point_networks.py +++ b/spf/model_training_and_inference/models/single_point_networks.py @@ -75,7 +75,7 @@ class SinglePointWithBeamformer(nn.Module): def __init__(self, model_config, global_config): super().__init__() self.prepare_input = PrepareInput(model_config, global_config) - self.net = FFNN( + self.single_point_with_beamformer_ffnn = FFNN( inputs=self.prepare_input.inputs, depth=model_config["depth"], # 4 hidden=model_config["hidden"], # 128 @@ -92,7 +92,11 @@ def forward(self, batch): # first dim odd / even is the radios return { "single": torch.nn.functional.normalize( - self.net(self.prepare_input.prepare_input(batch)).abs(), dim=2, p=1 + self.single_point_with_beamformer_ffnn( + self.prepare_input.prepare_input(batch) + ).abs(), + dim=2, + p=1, ) } @@ -104,7 +108,7 @@ def __init__(self, model_config, global_config): model_config["single"], global_config ) self.detach = model_config.get("detach", True) - self.net = FFNN( + self.paired_single_point_with_beamformer_ffnn = FFNN( inputs=global_config["nthetas"] * 2, depth=model_config["depth"], # 4 hidden=model_config["hidden"], # 128 @@ -123,7 +127,7 @@ def forward(self, batch): single_radio_estimates, self.detach ) - x = self.net( + x = self.paired_single_point_with_beamformer_ffnn( torch.concatenate( [single_radio_estimates_input[::2], single_radio_estimates_input[1::2]], dim=2, diff --git a/spf/notebooks/load_model_checkpoint.ipynb b/spf/notebooks/load_model_checkpoint.ipynb new file mode 100644 index 0000000..ffadd89 --- /dev/null +++ b/spf/notebooks/load_model_checkpoint.ipynb @@ -0,0 +1,123 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from spf.scripts.train_single_point import (\n", + " load_checkpoint,\n", + " load_config_from_fn,\n", + " load_model,\n", + ")\n", + "\n", + "\n", + "def load_model_and_config_from_config_fn_and_checkpoint(config_fn, checkpoint_fn):\n", + " config = load_config_from_fn(config_fn)\n", + " config[\"optim\"][\"checkpoint\"] = checkpoint_fn\n", + " m = load_model(config[\"model\"], config[\"global\"]).to(config[\"optim\"][\"device\"])\n", + " m, _, _, _, _ = load_checkpoint(\n", + " checkpoint_fn=config[\"optim\"][\"checkpoint\"],\n", + " config=config,\n", + " model=m,\n", + " optimizer=None,\n", + " scheduler=None,\n", + " force_load=True,\n", + " )\n", + " return m, config\n", + "\n", + "\n", + "def convert_datasets_config_to_inference(datasets_config, ds_fn):\n", + " datasets_config = datasets_config.copy()\n", + " datasets_config.update(\n", + " {\n", + " \"batch_size\": 1,\n", + " \"flip\": False,\n", + " \"double_flip\": False,\n", + " \"precompute_cache\": \"/home/mouse9911/precompute_cache_chunk16_sept\",\n", + " \"shuffle\": False,\n", + " \"skip_qc\": True,\n", + " \"snapshots_adjacent_stride\": 1,\n", + " \"train_snapshots_per_session\": 1,\n", + " \"val_snapshots_per_session\": 1,\n", + " \"random_snapshot_size\": False,\n", + " \"snapshots_stride\": 1,\n", + " \"train_paths\": [ds_fn],\n", + " \"train_on_val\": True,\n", + " \"workers\": 1,\n", + " }\n", + " )\n", + " return datasets_config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from spf.scripts.train_single_point import load_dataloaders\n", + "\n", + "from tqdm import tqdm\n", + "\n", + "config_fn = \"/home/mouse9911/gits/spf/nov2_checkpoints/nov2_small_paired_checkpoints_inputdo0p3/config.yml\"\n", + "checkpoint_fn = \"/home/mouse9911/gits/spf/nov2_checkpoints/nov2_small_paired_checkpoints_inputdo0p3/best.pth\"\n", + "ds_fn = \"/mnt/4tb_ssd/nosig_data/wallarrayv3_2024_08_21_10_30_58_nRX2_bounce_spacing0p05075.zarr\"\n", + "\n", + "# load model\n", + "model, config = load_model_and_config_from_config_fn_and_checkpoint(\n", + " config_fn=config_fn, checkpoint_fn=checkpoint_fn\n", + ")\n", + "\n", + "# load datasets config\n", + "datasets_config = convert_datasets_config_to_inference(\n", + " config[\"datasets\"],\n", + " ds_fn=ds_fn,\n", + ")\n", + "\n", + "# load dataloader\n", + "optim_config = {\"device\": \"cuda\", \"dtype\": torch.float32}\n", + "global_config = {\"nthetas\": 65, \"n_radios\": 2, \"seed\": 0, \"beamformer_input\": True}\n", + "train_dataloader, val_dataloader = load_dataloaders(\n", + " datasets_config, optim_config, config[\"global\"], step=0, epoch=0\n", + ")\n", + "\n", + "# run inference\n", + "model.eval()\n", + "for _, val_batch_data in enumerate(tqdm(val_dataloader, leave=False)):\n", + " val_batch_data = val_batch_data.to(config[\"optim\"][\"device\"])\n", + " output = model(val_batch_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "spf", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/spf/scripts/train_single_point.py b/spf/scripts/train_single_point.py index d450d0f..c502539 100644 --- a/spf/scripts/train_single_point.py +++ b/spf/scripts/train_single_point.py @@ -25,7 +25,7 @@ SinglePointWithBeamformer, TrajPairedMultiPointWithBeamformer, ) -from spf.rf import torch_pi_norm, torch_reduce_theta_to_positive_y +from spf.rf import torch_pi_norm from spf.utils import StatefulBatchsampler @@ -64,7 +64,8 @@ def load_dataloaders(datasets_config, optim_config, global_config, step=0, epoch # glob.glob('./[0-9].*') train_dataset_filenames = expand_wildcards_and_join(datasets_config["train_paths"]) - random.shuffle(train_dataset_filenames) + if datasets_config["shuffle"]: + random.shuffle(train_dataset_filenames) if datasets_config.get("train_on_val", False): val_paths = train_dataset_filenames @@ -140,12 +141,15 @@ def load_dataloaders(datasets_config, optim_config, global_config, step=0, epoch train_idxs = range(len(train_ds)) val_idxs = list(range(len(val_ds))) - random.shuffle(val_idxs) - val_idxs = val_idxs[ - : max( - 1, int(len(val_idxs) * datasets_config.get("val_subsample_fraction", 1.0)) - ) - ] + if datasets_config["shuffle"]: + random.shuffle(val_idxs) + if not datasets_config.get("train_on_val", False): + val_idxs = val_idxs[ + : max( + 1, + int(len(val_idxs) * datasets_config.get("val_subsample_fraction", 1.0)), + ) + ] train_ds = torch.utils.data.Subset(train_ds, train_idxs) val_ds = torch.utils.data.Subset(val_ds, val_idxs) @@ -272,7 +276,9 @@ def save_model( yaml.dump(running_config, outfile) -def load_checkpoint(checkpoint_fn, config, model, optimizer, scheduler): +def load_checkpoint( + checkpoint_fn, config, model, optimizer, scheduler, force_load=False +): logging.info(f"Loading checkpoint {checkpoint_fn}") checkpoint = torch.load(checkpoint_fn, map_location=torch.device("cpu")) @@ -293,26 +299,37 @@ def load_checkpoint(checkpoint_fn, config, model, optimizer, scheduler): # scheduler_being_loaded.load_state_dict(checkpoint["scheduler_state_dict"]) # check if we loading a single network - if config["model"].get("load_single", False): - logging.info("Loading single_radio_net only") - model.single_radio_net.load_state_dict(checkpoint["model_state_dict"]) - for param in model.single_radio_net.parameters(): - param.requires_grad = False - # model.single_radio_net = FrozenModule(model_being_loaded) - return (model, optimizer, scheduler, 0, 0) # epoch # step - elif config["model"].get("load_paired", False): - # check if we loading a paired network - logging.info("Loading paired_radio net only") - model.multi_radio_net.load_state_dict(checkpoint["model_state_dict"]) - for param in model.multi_radio_net.parameters(): - param.requires_grad = False - # breakpoint() - return (model, optimizer, scheduler, 0, 0) # epoch # step + if not force_load: + if config["model"].get("load_single", False): + logging.info("Loading single_radio_net only") + model.single_radio_net.load_state_dict(checkpoint["model_state_dict"]) + for param in model.single_radio_net.parameters(): + param.requires_grad = False + # model.single_radio_net = FrozenModule(model_being_loaded) + return (model, optimizer, scheduler, 0, 0) # epoch # step + elif config["model"].get("load_paired", False): + # check if we loading a paired network + logging.info("Loading paired_radio net only") + model.multi_radio_net.load_state_dict(checkpoint["model_state_dict"]) + for param in model.multi_radio_net.parameters(): + param.requires_grad = False + # breakpoint() + return (model, optimizer, scheduler, 0, 0) # epoch # step # else + logging.debug("loading_checkpoint: checkpoint state dict") + for key, v in checkpoint["model_state_dict"].items(): + logging.debug(f"\t{key}\t{v.shape}") + + logging.debug("loading_checkpoint: model state dict") + for key, v in model.state_dict().items(): + logging.debug(f"\t{key}\t{v.shape}") + model.load_state_dict(checkpoint["model_state_dict"]) - optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + if scheduler is not None: + scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) return ( model, @@ -633,7 +650,7 @@ def train_single_point(args): if args.output is None: args.output = datetime.datetime.now().strftime("spf-run-%Y-%m-%d_%H-%M-%S") try: - os.mkdir(args.output) + os.makedirs(args.output) except FileExistsError: pass @@ -677,7 +694,16 @@ def train_single_point(args): step = 0 start_epoch = 0 - if "checkpoint" in config["optim"]: + if args.resume_from is not None: + m, optimizer, scheduler, start_epoch, step = load_checkpoint( + checkpoint_fn=args.resume_from, + config=config, + model=m, + optimizer=optimizer, + scheduler=scheduler, + force_load=True, + ) + elif "checkpoint" in config["optim"]: m, optimizer, scheduler, start_epoch, step = load_checkpoint( checkpoint_fn=config["optim"]["checkpoint"], config=config, @@ -893,6 +919,13 @@ def get_parser_filter(): help="config file", required=True, ) + parser.add_argument( + "--resume-from", + type=str, + help="resume from checkpoint file", + default=None, + required=False, + ) parser.add_argument( "-o", "--output",