create an inference cache to be able to run repeated EKF / PF faster
misko committed Nov 3, 2024
1 parent ed648fb commit c1f047c
Showing 5 changed files with 275 additions and 27 deletions.
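Summary of the change: model outputs for a dataset are written once to disk as compressed .npz files, keyed by the dataset's basename plus md5 checksums of the checkpoint and config files, so repeated EKF / particle-filter runs reload cached predictions instead of re-running the network. A new script, spf/scripts/create_inference_cache.py, prepopulates the cache across many datasets in parallel.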
5 changes: 4 additions & 1 deletion spf/filters/filters.py
@@ -262,6 +262,8 @@ def single_radio_mse_theta_metrics(trajectory, ground_truth_thetas):
ground_truth_reduced_theta = torch.as_tensor(
reduce_theta_to_positive_y(ground_truth_thetas)
)
assert ground_truth_reduced_theta.shape[0] >= pred_theta.shape[0]
ground_truth_reduced_theta = ground_truth_reduced_theta[: pred_theta.shape[0]]
assert pred_theta.ndim == 1 and ground_truth_reduced_theta.ndim == 1
return {
"mse_single_radio_theta": (
@@ -298,6 +300,7 @@ def trajectory(
return_particles=False,
debug=False,
seed=0,
steps=None,
):
self.generator = torch.Generator()
self.generator.manual_seed(seed)
@@ -306,7 +309,7 @@
)
self.weights = torch.ones((N,), dtype=torch.float64) / N
trajectory = []
-        for idx in range(len(self.ds)):
+        for idx in range(len(self.ds) if steps is None else min(steps, len(self.ds))):
self.predict(
dt=1.0,
noise_std=noise_std,
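The new steps argument caps how many dataset snapshots the trajectory loop consumes, which keeps repeated evaluation runs short. A minimal sketch of a call, assuming a filter instance pf wrapping a dataset pf.ds (the instance name and other keyword values are hypothetical; only the steps keyword comes from this commit):

# `pf` is a hypothetical particle-filter instance; only `steps` is new here.
traj = pf.trajectory(
    noise_std=0.1,
    return_particles=False,
    seed=0,
    steps=100,  # iterate min(100, len(pf.ds)) snapshots instead of the full dataset
)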
spf/model_training_and_inference/models/single_point_networks_inference.py
@@ -1,13 +1,25 @@
import hashlib
import os

import numpy as np
import torch
from tqdm import tqdm

from spf.scripts.train_single_point import (
load_checkpoint,
load_config_from_fn,
load_dataloaders,
load_model,
)


-def load_model_and_config_from_config_fn_and_checkpoint(config_fn, checkpoint_fn):
+def load_model_and_config_from_config_fn_and_checkpoint(
+    config_fn, checkpoint_fn, device=None
+):
config = load_config_from_fn(config_fn)
config["optim"]["checkpoint"] = checkpoint_fn
if device is not None:
config["optim"]["device"] = device
m = load_model(config["model"], config["global"]).to(config["optim"]["device"])
m, _, _, _, _ = load_checkpoint(
checkpoint_fn=config["optim"]["checkpoint"],
@@ -20,14 +32,15 @@ def load_model_and_config_from_config_fn_and_checkpoint(
return m, config


-def convert_datasets_config_to_inference(datasets_config, ds_fn):
+def convert_datasets_config_to_inference(
+    datasets_config, ds_fn, precompute_cache, batch_size=1, workers=1
+):
datasets_config = datasets_config.copy()
datasets_config.update(
{
"batch_size": 1,
"batch_size": batch_size,
"flip": False,
"double_flip": False,
"precompute_cache": "/home/mouse9911/precompute_cache_chunk16_sept",
"shuffle": False,
"skip_qc": True,
"snapshots_adjacent_stride": 1,
@@ -37,7 +50,110 @@ def convert_datasets_config_to_inference(datasets_config, ds_fn):
"snapshots_stride": 1,
"train_paths": [ds_fn],
"train_on_val": True,
"workers": 1,
"workers": workers,
}
)
if precompute_cache is not None:
datasets_config.update({"precompute_cache": precompute_cache})
return datasets_config


def get_md5_of_file(fn, cache_md5=True):
    if os.path.exists(fn + ".md5"):
        with open(fn + ".md5", "r") as f:
            return f.readline().strip()
hash_md5 = hashlib.md5()
with open(fn, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)

md5 = hash_md5.hexdigest()
if cache_md5:
with open(fn + ".md5", "w") as f:
f.write(md5)
return md5


def get_inference_on_ds(
ds_fn,
config_fn,
checkpoint_fn,
inference_cache=None,
device="cpu",
batch_size=128,
workers=8,
precompute_cache=None,
):
if inference_cache is None:
return run_inference_on_ds(
ds_fn=ds_fn,
config_fn=config_fn,
checkpoint_fn=checkpoint_fn,
device=device,
batch_size=batch_size,
workers=workers,
precompute_cache=precompute_cache,
)
config_checksum = get_md5_of_file(config_fn)
checkpoint_checksum = get_md5_of_file(checkpoint_fn)
ds_basename = os.path.basename(ds_fn)
inference_cache_fn = (
f"{inference_cache}/{ds_basename}/{checkpoint_checksum}/{config_checksum}.npz"
)
if os.path.exists(inference_cache_fn):
return np.load(inference_cache_fn)
# run inference
os.makedirs(os.path.dirname(inference_cache_fn), exist_ok=True)
results = run_inference_on_ds(
ds_fn=ds_fn,
config_fn=config_fn,
checkpoint_fn=checkpoint_fn,
device=device,
batch_size=batch_size,
workers=workers,
precompute_cache=precompute_cache,
)
results = {key: value.numpy() for key, value in results.items()}
np.savez_compressed(inference_cache_fn + ".tmp", **results)
os.rename(inference_cache_fn + ".tmp.npz", inference_cache_fn)
return results


def run_inference_on_ds(
ds_fn, config_fn, checkpoint_fn, device, batch_size, workers, precompute_cache
):
# load model and model config
model, config = load_model_and_config_from_config_fn_and_checkpoint(
config_fn=config_fn, checkpoint_fn=checkpoint_fn, device=device
)

# prepare inference configs
optim_config = {"device": device, "dtype": torch.float32}
datasets_config = convert_datasets_config_to_inference(
config["datasets"],
ds_fn=ds_fn,
batch_size=batch_size,
workers=workers,
precompute_cache=precompute_cache,
)

_, val_dataloader = load_dataloaders(
datasets_config, optim_config, config["global"], step=0, epoch=0
)
model.eval()
outputs = []
with torch.no_grad():
for _, val_batch_data in enumerate(val_dataloader):
val_batch_data = val_batch_data.to(optim_config["device"])
outputs.append(model(val_batch_data))
results = {"single": torch.vstack([output["single"] for output in outputs]).cpu()}
sessions_times_radios, snapshots, ntheta = results["single"].shape
sessions = sessions_times_radios // 2
radios = 2
results["single"] = results["single"].reshape(sessions, radios, snapshots, ntheta)

if "paired" in outputs[0]:
results["paired"] = torch.vstack([output["paired"] for output in outputs]).cpu()
results["paired"] = results["paired"].reshape(
sessions, radios, snapshots, ntheta
)
return results
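get_inference_on_ds first checks the cache for <inference_cache>/<dataset basename>/<checkpoint md5>/<config md5>.npz and returns it on a hit; on a miss it runs the model, writes the results atomically via a .tmp file plus rename, and memoizes checksums in .md5 sidecar files so large checkpoints are not re-hashed on every call. A sketch of a cached call (all paths below are hypothetical placeholders):

from spf.model_training_and_inference.models.single_point_networks_inference import (
    get_inference_on_ds,
)

# Hypothetical paths; the first call populates the cache, later calls load the .npz.
results = get_inference_on_ds(
    ds_fn="/data/rover_run_01.zarr",
    config_fn="/models/config.yml",
    checkpoint_fn="/models/best.pth",
    inference_cache="/data/inference_cache",
    precompute_cache="/data/precompute_cache",
    device="cpu",
)
# results["single"] has shape (sessions, radios=2, snapshots, ntheta);
# results["paired"] is present as well when the model emits paired outputs.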
1 change: 1 addition & 0 deletions spf/scripts/create_empirical_p_dist.py
@@ -4,6 +4,7 @@

import numpy as np
import torch
import tqdm
from matplotlib import pyplot as plt

from spf.dataset.spf_dataset import v5spfdataset
78 changes: 78 additions & 0 deletions spf/scripts/create_inference_cache.py
@@ -0,0 +1,78 @@
import argparse
from functools import partial
from multiprocessing import Pool

import tqdm

from spf.model_training_and_inference.models.single_point_networks_inference import (
get_inference_on_ds,
)

if __name__ == "__main__":

def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config-fn",
type=str,
required=True,
)
parser.add_argument(
"--checkpoint-fn",
type=str,
required=True,
)
parser.add_argument(
"--inference-cache",
type=str,
required=True,
)
parser.add_argument("--parallel", type=int, required=False, default=16)

parser.add_argument("--precompute-cache", type=str, required=True)
parser.add_argument("--workers", type=int, required=False, default=0)
parser.add_argument("--device", type=str, required=False, default="cuda")
parser.add_argument(
"-d",
"--datasets",
type=str,
help="dataset prefixes",
nargs="+",
required=True,
)
parser.add_argument(
"--debug",
action=argparse.BooleanOptionalAction,
default=False,
)
return parser

parser = get_parser()
args = parser.parse_args()

run_fn = partial(
get_inference_on_ds,
config_fn=args.config_fn,
checkpoint_fn=args.checkpoint_fn,
device=args.device,
inference_cache=args.inference_cache,
batch_size=64,
workers=0,
precompute_cache=args.precompute_cache,
)

if args.debug:
results = list(
tqdm.tqdm(
map(run_fn, args.datasets),
total=len(args.datasets),
)
)
else:
        with Pool(args.parallel) as pool:
results = list(
tqdm.tqdm(
pool.imap(run_fn, args.datasets),
total=len(args.datasets),
)
)
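An example invocation that fans inference out over a process pool, one dataset per worker (all paths are hypothetical):

python spf/scripts/create_inference_cache.py \
    --config-fn /models/config.yml \
    --checkpoint-fn /models/best.pth \
    --inference-cache /data/inference_cache \
    --precompute-cache /data/precompute_cache \
    --device cuda \
    --parallel 4 \
    -d /data/rover_run_01.zarr /data/rover_run_02.zarr

With --debug, the datasets are processed serially in-process instead, which makes stack traces easier to read.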