diff --git a/spf/filters/filters.py b/spf/filters/filters.py index 86527e2..b7090bd 100644 --- a/spf/filters/filters.py +++ b/spf/filters/filters.py @@ -262,6 +262,8 @@ def single_radio_mse_theta_metrics(trajectory, ground_truth_thetas): ground_truth_reduced_theta = torch.as_tensor( reduce_theta_to_positive_y(ground_truth_thetas) ) + assert ground_truth_reduced_theta.shape[0] >= pred_theta.shape[0] + ground_truth_reduced_theta = ground_truth_reduced_theta[: pred_theta.shape[0]] assert pred_theta.ndim == 1 and ground_truth_reduced_theta.ndim == 1 return { "mse_single_radio_theta": ( @@ -298,6 +300,7 @@ def trajectory( return_particles=False, debug=False, seed=0, + steps=None, ): self.generator = torch.Generator() self.generator.manual_seed(seed) @@ -306,7 +309,7 @@ def trajectory( ) self.weights = torch.ones((N,), dtype=torch.float64) / N trajectory = [] - for idx in range(len(self.ds)): + for idx in range(len(self.ds) if steps == None else min(steps, len(self.ds))): self.predict( dt=1.0, noise_std=noise_std, diff --git a/spf/model_training_and_inference/models/single_point_networks_inference.py b/spf/model_training_and_inference/models/single_point_networks_inference.py index bc7f059..68c62fe 100644 --- a/spf/model_training_and_inference/models/single_point_networks_inference.py +++ b/spf/model_training_and_inference/models/single_point_networks_inference.py @@ -1,13 +1,25 @@ +import hashlib +import os + +import numpy as np +import torch +from tqdm import tqdm + from spf.scripts.train_single_point import ( load_checkpoint, load_config_from_fn, + load_dataloaders, load_model, ) -def load_model_and_config_from_config_fn_and_checkpoint(config_fn, checkpoint_fn): +def load_model_and_config_from_config_fn_and_checkpoint( + config_fn, checkpoint_fn, device=None +): config = load_config_from_fn(config_fn) config["optim"]["checkpoint"] = checkpoint_fn + if device is not None: + config["optim"]["device"] = device m = load_model(config["model"], config["global"]).to(config["optim"]["device"]) m, _, _, _, _ = load_checkpoint( checkpoint_fn=config["optim"]["checkpoint"], @@ -20,14 +32,15 @@ def load_model_and_config_from_config_fn_and_checkpoint(config_fn, checkpoint_fn return m, config -def convert_datasets_config_to_inference(datasets_config, ds_fn): +def convert_datasets_config_to_inference( + datasets_config, ds_fn, precompute_cache, batch_size=1, workers=1 +): datasets_config = datasets_config.copy() datasets_config.update( { - "batch_size": 1, + "batch_size": batch_size, "flip": False, "double_flip": False, - "precompute_cache": "/home/mouse9911/precompute_cache_chunk16_sept", "shuffle": False, "skip_qc": True, "snapshots_adjacent_stride": 1, @@ -37,7 +50,110 @@ def convert_datasets_config_to_inference(datasets_config, ds_fn): "snapshots_stride": 1, "train_paths": [ds_fn], "train_on_val": True, - "workers": 1, + "workers": workers, } ) + if precompute_cache is not None: + datasets_config.update({"precompute_cache": precompute_cache}) return datasets_config + + +def get_md5_of_file(fn, cache_md5=True): + if os.path.exists(fn + ".md5"): + return open(fn + ".md5", "r").readlines()[0].strip() + hash_md5 = hashlib.md5() + with open(fn, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + + md5 = hash_md5.hexdigest() + if cache_md5: + with open(fn + ".md5", "w") as f: + f.write(md5) + return md5 + + +def get_inference_on_ds( + ds_fn, + config_fn, + checkpoint_fn, + inference_cache=None, + device="cpu", + batch_size=128, + workers=8, + precompute_cache=None, +): + if inference_cache is None: + return run_inference_on_ds( + ds_fn=ds_fn, + config_fn=config_fn, + checkpoint_fn=checkpoint_fn, + device=device, + batch_size=batch_size, + workers=workers, + precompute_cache=precompute_cache, + ) + config_checksum = get_md5_of_file(config_fn) + checkpoint_checksum = get_md5_of_file(checkpoint_fn) + ds_basename = os.path.basename(ds_fn) + inference_cache_fn = ( + f"{inference_cache}/{ds_basename}/{checkpoint_checksum}/{config_checksum}.npz" + ) + if os.path.exists(inference_cache_fn): + return np.load(inference_cache_fn) + # run inference + os.makedirs(os.path.dirname(inference_cache_fn), exist_ok=True) + results = run_inference_on_ds( + ds_fn=ds_fn, + config_fn=config_fn, + checkpoint_fn=checkpoint_fn, + device=device, + batch_size=batch_size, + workers=workers, + precompute_cache=precompute_cache, + ) + results = {key: value.numpy() for key, value in results.items()} + np.savez_compressed(inference_cache_fn + ".tmp", **results) + os.rename(inference_cache_fn + ".tmp.npz", inference_cache_fn) + return results + + +def run_inference_on_ds( + ds_fn, config_fn, checkpoint_fn, device, batch_size, workers, precompute_cache +): + # load model and model config + model, config = load_model_and_config_from_config_fn_and_checkpoint( + config_fn=config_fn, checkpoint_fn=checkpoint_fn, device=device + ) + + # prepare inference configs + optim_config = {"device": device, "dtype": torch.float32} + datasets_config = convert_datasets_config_to_inference( + config["datasets"], + ds_fn=ds_fn, + batch_size=batch_size, + workers=workers, + precompute_cache=precompute_cache, + ) + + _, val_dataloader = load_dataloaders( + datasets_config, optim_config, config["global"], step=0, epoch=0 + ) + model.eval() + outputs = [] + with torch.no_grad(): + for _, val_batch_data in enumerate(val_dataloader): + val_batch_data = val_batch_data.to(optim_config["device"]) + outputs.append(model(val_batch_data)) + results = {"single": torch.vstack([output["single"] for output in outputs]).cpu()} + sessions_times_radios, snapshots, ntheta = results["single"].shape + sessions = sessions_times_radios // 2 + radios = 2 + results["single"] = results["single"].reshape(sessions, radios, snapshots, ntheta) + + if "paired" in outputs[0]: + results["paired"] = torch.vstack([output["paired"] for output in outputs]).cpu() + results["paired"] = results["paired"].reshape( + sessions, radios, snapshots, ntheta + ) + return results diff --git a/spf/scripts/create_empirical_p_dist.py b/spf/scripts/create_empirical_p_dist.py index f2f9b81..47b99c7 100644 --- a/spf/scripts/create_empirical_p_dist.py +++ b/spf/scripts/create_empirical_p_dist.py @@ -4,6 +4,7 @@ import numpy as np import torch +import tqdm from matplotlib import pyplot as plt from spf.dataset.spf_dataset import v5spfdataset diff --git a/spf/scripts/create_inference_cache.py b/spf/scripts/create_inference_cache.py new file mode 100644 index 0000000..c005475 --- /dev/null +++ b/spf/scripts/create_inference_cache.py @@ -0,0 +1,78 @@ +import argparse +from functools import partial +from multiprocessing import Pool + +import tqdm + +from spf.model_training_and_inference.models.single_point_networks_inference import ( + get_inference_on_ds, +) + +if __name__ == "__main__": + + def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config-fn", + type=str, + required=True, + ) + parser.add_argument( + "--checkpoint-fn", + type=str, + required=True, + ) + parser.add_argument( + "--inference-cache", + type=str, + required=True, + ) + parser.add_argument("--parallel", type=int, required=False, default=16) + + parser.add_argument("--precompute-cache", type=str, required=True) + parser.add_argument("--workers", type=int, required=False, default=0) + parser.add_argument("--device", type=str, required=False, default="cuda") + parser.add_argument( + "-d", + "--datasets", + type=str, + help="dataset prefixes", + nargs="+", + required=True, + ) + parser.add_argument( + "--debug", + action=argparse.BooleanOptionalAction, + default=False, + ) + return parser + + parser = get_parser() + args = parser.parse_args() + + run_fn = partial( + get_inference_on_ds, + config_fn=args.config_fn, + checkpoint_fn=args.checkpoint_fn, + device=args.device, + inference_cache=args.inference_cache, + batch_size=64, + workers=0, + precompute_cache=args.precompute_cache, + ) + + if args.debug: + results = list( + tqdm.tqdm( + map(run_fn, args.datasets), + total=len(args.datasets), + ) + ) + else: + with Pool(args.parallel) as pool: # cpu_count()) # cpu_count() // 4) + results = list( + tqdm.tqdm( + pool.imap(run_fn, args.datasets), + total=len(args.datasets), + ) + ) diff --git a/tests/test_new_models.py b/tests/test_new_models.py index 63bc468..985ff12 100644 --- a/tests/test_new_models.py +++ b/tests/test_new_models.py @@ -11,6 +11,7 @@ from spf.dataset.spf_dataset import v5_collate_keys_fast, v5spfdataset from spf.model_training_and_inference.models.single_point_networks_inference import ( convert_datasets_config_to_inference, + get_inference_on_ds, load_model_and_config_from_config_fn_and_checkpoint, ) from spf.scripts.train_single_point import ( @@ -135,13 +136,17 @@ def dataloader_inference(model, global_config, datasets_config, optim_config): for _, val_batch_data in enumerate(tqdm(val_dataloader, leave=False)): val_batch_data = val_batch_data.to(optim_config["device"]) outputs.append(model(val_batch_data)) + results = {"single": torch.vstack([output["single"] for output in outputs])} + + sessions_times_radios, snapshots, ntheta = results["single"].shape + sessions = sessions_times_radios // 2 + radios = 2 + results["single"] = results["single"].reshape(sessions, radios, snapshots, ntheta) - results = { - "single": torch.vstack([output["single"].unsqueeze(0) for output in outputs]) - } if "paired" in outputs[0]: - results["paired"] = torch.vstack( - [output["paired"].unsqueeze(0) for output in outputs] + results["paired"] = torch.vstack([output["paired"] for output in outputs]) + results["paired"] = results["paired"].reshape( + sessions, radios, snapshots, ntheta ) return results @@ -166,12 +171,17 @@ def single_example_inference(model, global_config, datasets_config, optim_config optim_config["device"] ) outputs.append(model(single_example)) - results = { - "single": torch.vstack([output["single"].unsqueeze(0) for output in outputs]) - } + results = {"single": torch.vstack([output["single"] for output in outputs])} + + sessions_times_radios, snapshots, ntheta = results["single"].shape + sessions = sessions_times_radios // 2 + radios = 2 + results["single"] = results["single"].reshape(sessions, radios, snapshots, ntheta) + if "paired" in outputs[0]: - results["paired"] = torch.vstack( - [output["paired"].unsqueeze(0) for output in outputs] + results["paired"] = torch.vstack([output["paired"] for output in outputs]) + results["paired"] = results["paired"].reshape( + sessions, radios, snapshots, ntheta ) return results @@ -180,7 +190,7 @@ def test_inference_single_checkpoint( single_net_checkpoint, perfect_circle_dataset_n7_with_empirical ): single_checkpoints_dir = single_net_checkpoint - _, _, zarr_fn = perfect_circle_dataset_n7_with_empirical + precompute_cache, _, zarr_fn = perfect_circle_dataset_n7_with_empirical ds_fn = f"{zarr_fn}.zarr" config_fn = f"{single_checkpoints_dir}/config.yml" @@ -193,25 +203,63 @@ def test_inference_single_checkpoint( # prepare inference configs optim_config = {"device": "cpu", "dtype": torch.float32} - global_config = {"nthetas": 65, "n_radios": 2, "seed": 0, "beamformer_input": True} datasets_config = convert_datasets_config_to_inference( - config["datasets"], - ds_fn=ds_fn, + config["datasets"], ds_fn=ds_fn, batch_size=3, precompute_cache=precompute_cache ) # inference using dataloader dataloader_results = dataloader_inference( - model, global_config, datasets_config, optim_config + model, config["global"], datasets_config, optim_config ) # run inference one at a time single_example_results = single_example_inference( - model, global_config, datasets_config, optim_config + model, config["global"], datasets_config, optim_config ) assert dataloader_results["single"].isclose(single_example_results["single"]).all() +def test_inference_single_checkpoint_against_ds_inference( + single_net_checkpoint, perfect_circle_dataset_n7_with_empirical +): + single_checkpoints_dir = single_net_checkpoint + precompute_cache, _, zarr_fn = perfect_circle_dataset_n7_with_empirical + + ds_fn = f"{zarr_fn}.zarr" + config_fn = f"{single_checkpoints_dir}/config.yml" + checkpoint_fn = f"{single_checkpoints_dir}/best.pth" + + # load model and model config + model, config = load_model_and_config_from_config_fn_and_checkpoint( + config_fn=config_fn, checkpoint_fn=checkpoint_fn + ) + + # prepare inference configs + optim_config = {"device": "cpu", "dtype": torch.float32} + datasets_config = convert_datasets_config_to_inference( + config["datasets"], ds_fn=ds_fn, batch_size=3, precompute_cache=precompute_cache + ) + + # run inference one at a time + single_example_results = single_example_inference( + model, config["global"], datasets_config, optim_config + ) + + results = get_inference_on_ds( + ds_fn, + config_fn, + checkpoint_fn, + inference_cache=None, + device="cpu", + batch_size=4, + workers=0, + precompute_cache=None, + ) + + assert results["single"].isclose(single_example_results["single"]).all() + + def test_inference_paired_checkpoint( single_net_checkpoint, paired_net_checkpoint_using_single_checkpoint, @@ -219,7 +267,7 @@ def test_inference_paired_checkpoint( ): # get single checkpoint results single_checkpoints_dir = single_net_checkpoint - _, _, zarr_fn = perfect_circle_dataset_n7_with_empirical + precompute_cache, _, zarr_fn = perfect_circle_dataset_n7_with_empirical ds_fn = f"{zarr_fn}.zarr" single_config_fn = f"{single_checkpoints_dir}/config.yml" @@ -232,15 +280,16 @@ def test_inference_paired_checkpoint( # prepare inference configs optim_config = {"device": "cpu", "dtype": torch.float32} - global_config = {"nthetas": 65, "n_radios": 2, "seed": 0, "beamformer_input": True} single_datasets_config = convert_datasets_config_to_inference( single_config["datasets"], ds_fn=ds_fn, + batch_size=3, + precompute_cache=precompute_cache, ) # inference using dataloader dataloader_single_results = dataloader_inference( - single_model, global_config, single_datasets_config, optim_config + single_model, single_config["global"], single_datasets_config, optim_config ) # get paired checkpoint results @@ -258,16 +307,17 @@ def test_inference_paired_checkpoint( paired_datasets_config = convert_datasets_config_to_inference( paired_config["datasets"], ds_fn=ds_fn, + precompute_cache=precompute_cache, ) # inference using dataloader dataloader_paired_results = dataloader_inference( - paired_model, global_config, paired_datasets_config, optim_config + paired_model, paired_config["global"], paired_datasets_config, optim_config ) # run inference one at a time single_example_paired_results = single_example_inference( - paired_model, global_config, paired_datasets_config, optim_config + paired_model, paired_config["global"], paired_datasets_config, optim_config ) assert (