
refactoring to split run methods (*mc, snpe) #42

Draft
wants to merge 10 commits into base: main
117 changes: 92 additions & 25 deletions sbibm/algorithms/sbi/mcabc.py
@@ -1,32 +1,16 @@
from typing import Optional, Tuple
import pickle
from typing import Optional, Tuple, Union

import torch
from sbi.inference import MCABC
from sbi.utils import KDEWrapper
from torch import Tensor

import sbibm
from sbibm.tasks.task import Task
from sbibm.utils.io import save_tensor_to_csv


def run(
task: Task,
num_samples: int,
num_simulations: int,
num_observation: Optional[int] = None,
observation: Optional[torch.Tensor] = None,
num_top_samples: Optional[int] = 100,
quantile: Optional[float] = None,
eps: Optional[float] = None,
distance: str = "l2",
batch_size: int = 1000,
save_distances: bool = False,
kde_bandwidth: Optional[str] = "cv",
sass: bool = False,
sass_fraction: float = 0.5,
sass_feature_expansion_degree: int = 3,
lra: bool = False,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
"""Runs REJ-ABC from `sbi`
__DOCSTRING__ = """Runs REJ-ABC from `sbi`

Choose one of `num_top_samples`, `quantile`, `eps`.

@@ -45,22 +29,52 @@ def run(
kde_bandwidth: If not None, will resample using KDE when necessary, set
e.g. to "cv" for cross-validated bandwidth selection
sass: If True, summary statistics are learned as in
Fearnhead & Prangle 2012.
Fearnhead & Prangle 2012
https://doi.org/10.1111/j.1467-9868.2011.01010.x
sass_fraction: Fraction of simulation budget to use for sass.
sass_feature_expansion_degree: Degree of polynomial expansion of the summary
statistics.
lra: If True, posterior samples are adjusted with
linear regression as in Beaumont et al. 2002.
linear regression as in Beaumont et al. 2002,
https://doi.org/10.1093/genetics/162.4.2025

"""


def build_posterior(
task: Task,
num_samples: int,
num_simulations: int,
num_observation: Optional[int] = None,
observation: Optional[torch.Tensor] = None,
num_top_samples: Optional[int] = 100,
quantile: Optional[float] = None,
eps: Optional[float] = None,
distance: str = "l2",
batch_size: int = 1000,
save_distances: bool = False,
kde_bandwidth: Optional[str] = "cv",
sass: bool = False,
sass_fraction: float = 0.5,
sass_feature_expansion_degree: int = 3,
lra: bool = False,
) -> Tuple[
Union[Tuple[Tensor, dict], Tuple[KDEWrapper, dict], Tensor, KDEWrapper], dict
]:
f"""
build_posterior constructs the inferred posterior object
{__DOCSTRING__}

Returns:
Samples from posterior, number of simulator calls, log probability of true params if computable
posterior wrapper, summary dictionary
"""
assert not (num_observation is None and observation is None)
assert not (num_observation is not None and observation is not None)

assert not (num_top_samples is None and quantile is None and eps is None)

log = sbibm.get_logger(__name__)
log.info(f"Running REJ-ABC")
log.info(f"Building REJ-ABC posterior")

prior = task.get_prior_dist()
simulator = task.get_simulator(max_calls=num_simulations)
@@ -103,6 +117,59 @@ def run(
if save_distances:
save_tensor_to_csv("distances.csv", summary["distances"])

return output, summary


def run(
task: Task,
num_samples: int,
num_simulations: int,
num_observation: Optional[int] = None,
observation: Optional[torch.Tensor] = None,
num_top_samples: Optional[int] = 100,
quantile: Optional[float] = None,
eps: Optional[float] = None,
distance: str = "l2",
batch_size: int = 1000,
save_distances: bool = False,
kde_bandwidth: Optional[str] = "cv",
sass: bool = False,
sass_fraction: float = 0.5,
sass_feature_expansion_degree: int = 3,
lra: bool = False,
posterior_path: Optional[str] = "",
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
f"""
{__DOCSTRING__}
posterior_path: filesystem location at which to store the posterior
(if None or empty, the posterior is not saved)

Returns:
Samples from posterior, number of simulator calls, log probability of true params if computable
"""
assert not (num_observation is None and observation is None)
assert not (num_observation is not None and observation is not None)
assert not (num_top_samples is None and quantile is None and eps is None)

inkwargs = {k: v for k, v in locals().items() if k != "posterior_path"}

log = sbibm.get_logger(__name__)
log.info(f"Running REJ-ABC")
simulator = task.get_simulator(max_calls=num_simulations)

output, summary = build_posterior(**inkwargs)
kde = kde_bandwidth is not None

if posterior_path:
if not kde:
log.info(
f"unable to save posterior as none was created (kde={kde}, kde_bandwidth={kde_bandwidth})"
)
else:
log.info(f"storing posterior at {posterior_path}")
with open(posterior_path, "wb") as ofile:
pickle.dump(output, ofile)

if kde:
kde_posterior = output
samples = kde_posterior.sample(num_samples)
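For context, a minimal usage sketch of the split REJ-ABC interface above. This is a sketch under assumptions, not code verified against this branch: the task name, budget values, and output file are illustrative, and the three return values follow the run() docstring.

import pickle

import sbibm
from sbibm.algorithms.sbi import mcabc

task = sbibm.get_task("gaussian_linear")  # illustrative task name

# run() wraps build_posterior() and, because kde_bandwidth is set,
# can pickle the resulting KDE posterior to posterior_path.
samples, num_sims, log_prob_true = mcabc.run(
    task=task,
    num_samples=1_000,
    num_simulations=10_000,
    num_observation=1,
    kde_bandwidth="cv",
    posterior_path="rej_abc_posterior.pkl",  # hypothetical output file
)

# Reload the pickled KDEWrapper later and draw fresh samples.
with open("rej_abc_posterior.pkl", "rb") as ifile:
    kde_posterior = pickle.load(ifile)
more_samples = kde_posterior.sample(1_000)

Note that with kde_bandwidth=None the output is a plain tensor of samples, so there is no posterior object to pickle; that is the case the "unable to save posterior" branch guards against.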
10 changes: 9 additions & 1 deletion sbibm/algorithms/sbi/sl.py
@@ -1,5 +1,5 @@
import logging
import math
import pickle
from typing import Any, Dict, Optional

import torch
@@ -68,6 +68,7 @@ def run(
mcmc_method: str = "slice_np",
mcmc_parameters: Dict[str, Any] = {},
diag_eps: float = 0.0,
posterior_path: Optional[str] = "",
) -> (torch.Tensor, int, Optional[torch.Tensor]):
"""Runs (S)NLE from `sbi`

@@ -82,6 +83,8 @@
mcmc_method: MCMC method
mcmc_parameters: MCMC parameters
diag_eps: Epsilon applied to diagonal
posterior_path: filesystem location at which to store the posterior
(if None or empty, the posterior is not saved)

Returns:
Samples from posterior, number of simulator calls, log probability of true params if computable
@@ -121,6 +124,11 @@ def run(

posterior = wrap_posterior(posterior, transforms)

if posterior_path:
log.info(f"storing posterior at {posterior_path}")
with open(posterior_path, "wb") as ofile:
pickle.dump(posterior, ofile)

# assert simulator.num_simulations == num_simulations

samples = posterior.sample((num_samples,)).detach()
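The posterior_path option added to sl.py pickles the wrapped posterior object itself rather than a KDE. A hedged round-trip sketch, with an illustrative file name; the sample() call mirrors the one at the end of run():

import pickle

# Load a posterior previously saved via run(..., posterior_path=...).
with open("snle_posterior.pkl", "rb") as ifile:  # hypothetical file
    posterior = pickle.load(ifile)

samples = posterior.sample((1_000,)).detach()

Since the pickle references sbi classes, unpickling requires an environment with a compatible sbi version.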
137 changes: 103 additions & 34 deletions sbibm/algorithms/sbi/smcabc.py
@@ -1,43 +1,18 @@
from typing import Optional, Tuple
import pickle
from typing import Optional, Tuple, Union

import pandas as pd
import torch
from sbi.inference import SMCABC
from sklearn.linear_model import LinearRegression
from sbi.utils import KDEWrapper
from torch import Tensor

import sbibm
from sbibm.tasks.task import Task

from .utils import clip_int


def run(
task: Task,
num_samples: int,
num_simulations: int,
num_observation: Optional[int] = None,
observation: Optional[torch.Tensor] = None,
population_size: Optional[int] = None,
distance: str = "l2",
epsilon_decay: float = 0.2,
distance_based_decay: bool = True,
ess_min: Optional[float] = None,
initial_round_factor: int = 5,
batch_size: int = 1000,
kernel: str = "gaussian",
kernel_variance_scale: float = 0.5,
use_last_pop_samples: bool = True,
algorithm_variant: str = "C",
save_summary: bool = False,
sass: bool = False,
sass_fraction: float = 0.5,
sass_feature_expansion_degree: int = 3,
lra: bool = False,
lra_sample_weights: bool = True,
kde_bandwidth: Optional[str] = "cv",
kde_sample_weights: bool = False,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
"""Runs SMC-ABC from `sbi`
__DOCSTRING__ = """Runs SMC-ABC from `sbi`

SMC-ABC supports two different ways of scheduling epsilon:
1) Exponential decay: eps_t+1 = epsilon_decay * eps_t
@@ -75,22 +50,56 @@ def run(
sass_feature_expansion_degree: Degree of polynomial expansion of the summary
statistics.
lra: If True, posterior samples are adjusted with
linear regression as in Beaumont et al. 2002.
linear regression as in Beaumont et al. 2002,
https://doi.org/10.1093/genetics/162.4.2025
lra_sample_weights: Whether to weigh LRA samples
kde_bandwidth: If not None, will resample using KDE when necessary, set
e.g. to "cv" for cross-validated bandwidth selection
kde_sample_weights: Whether to weigh KDE samples


"""


def build_posterior(
task: Task,
num_samples: int,
num_simulations: int,
num_observation: Optional[int] = None,
observation: Optional[torch.Tensor] = None,
population_size: Optional[int] = None,
distance: str = "l2",
epsilon_decay: float = 0.2,
distance_based_decay: bool = True,
ess_min: Optional[float] = None,
initial_round_factor: int = 5,
batch_size: int = 1000,
kernel: str = "gaussian",
kernel_variance_scale: float = 0.5,
use_last_pop_samples: bool = True,
algorithm_variant: str = "C",
save_summary: bool = False,
sass: bool = False,
sass_fraction: float = 0.5,
sass_feature_expansion_degree: int = 3,
lra: bool = False,
lra_sample_weights: bool = True,
kde_bandwidth: Optional[str] = "cv",
kde_sample_weights: bool = False,
) -> Tuple[
Union[Tuple[Tensor, dict], Tuple[KDEWrapper, dict], Tensor, KDEWrapper], dict
]:
f"""
build_posterior constructs the inferred posterior object
{__DOCSTRING__}

Returns:
Samples from posterior, number of simulator calls, log probability of true params if computable
posterior wrapper, summary dictionary
"""
assert not (num_observation is None and observation is None)
assert not (num_observation is not None and observation is not None)

log = sbibm.get_logger(__name__)
smc_papers = dict(A="Toni 2010", B="Sisson et al. 2007", C="Beaumont et al. 2009")
log.info(f"Running SMC-ABC as in {smc_papers[algorithm_variant]}.")

prior = task.get_prior_dist()
simulator = task.get_simulator(max_calls=num_simulations)
@@ -151,6 +160,66 @@ def run(

assert simulator.num_simulations == num_simulations

return output, summary


def run(
task: Task,
num_samples: int,
num_simulations: int,
num_observation: Optional[int] = None,
observation: Optional[torch.Tensor] = None,
population_size: Optional[int] = None,
distance: str = "l2",
epsilon_decay: float = 0.2,
distance_based_decay: bool = True,
ess_min: Optional[float] = None,
initial_round_factor: int = 5,
batch_size: int = 1000,
kernel: str = "gaussian",
kernel_variance_scale: float = 0.5,
use_last_pop_samples: bool = True,
algorithm_variant: str = "C",
save_summary: bool = False,
sass: bool = False,
sass_fraction: float = 0.5,
sass_feature_expansion_degree: int = 3,
lra: bool = False,
lra_sample_weights: bool = True,
kde_bandwidth: Optional[str] = "cv",
kde_sample_weights: bool = False,
posterior_path: Optional[str] = "",
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
f"""
{__DOCSTRING__}
posterior_path: filesystem location at which to store the posterior
(if None or empty, the posterior is not saved)

Returns:
Samples from posterior, number of simulator calls, log probability of true params if computable
"""
assert not (num_observation is None and observation is None)
assert not (num_observation is not None and observation is not None)

inkwargs = {k: v for k, v in locals().items() if k != "posterior_path"}

log = sbibm.get_logger(__name__)
smc_papers = dict(A="Toni 2010", B="Sisson et al. 2007", C="Beaumont et al. 2009")
log.info(f"Building SMC-ABC Posterior as in {smc_papers[algorithm_variant]}.")

simulator = task.get_simulator(max_calls=num_simulations)
kde = kde_bandwidth is not None
output, summary = build_posterior(**inkwargs)
if posterior_path:
if not kde:
log.info(
f"unable to save posterior as none was created (kde={kde}, kde_bandwidth={kde_bandwidth})"
)
else:
log.info(f"storing posterior at {posterior_path}")
with open(posterior_path, "wb") as ofile:
pickle.dump(output, ofile)

# Return samples from kde or raw samples.
if kde:
kde_posterior = output
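Both refactored run() functions forward their arguments to build_posterior() by filtering locals(). A self-contained sketch of that pattern, under illustrative names rather than the PR's real signatures:

def build_posterior(num_samples, num_simulations):
    # Stand-in for the real build_posterior().
    return {"num_samples": num_samples, "num_simulations": num_simulations}, {}


def run(num_samples, num_simulations, posterior_path=""):
    # locals() evaluated here contains exactly the call parameters, so
    # stripping posterior_path leaves the kwargs build_posterior() accepts.
    inkwargs = {k: v for k, v in locals().items() if k != "posterior_path"}
    output, summary = build_posterior(**inkwargs)
    return output, summary


print(run(10, 100, posterior_path="unused.pkl"))

This only works because inkwargs is computed before any other local names (log, simulator, ...) are bound; both run() bodies in the diff respect that ordering.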