Allow passing compile_kwargs to step inner functions
ricardoV94 committed Nov 27, 2024
1 parent 8ac7108 commit bd232d2
Showing 6 changed files with 55 additions and 16 deletions.
16 changes: 15 additions & 1 deletion pymc/sampling/mcmc.py
@@ -104,6 +104,7 @@ def instantiate_steppers(
*,
step_kwargs: dict[str, dict] | None = None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
) -> Step | list[Step]:
"""Instantiate steppers assigned to the model variables.
@@ -146,6 +147,7 @@ def instantiate_steppers(
vars=vars,
model=model,
initial_point=initial_point,
compile_kwargs=compile_kwargs,
**kwargs,
)
steps.append(step)
@@ -434,6 +436,7 @@ def sample(
callback=None,
mp_ctx=None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,
**kwargs,
) -> InferenceData: ...

@@ -466,6 +469,7 @@ def sample(
mp_ctx=None,
model: Model | None = None,
blas_cores: int | None | Literal["auto"] = "auto",
compile_kwargs: dict | None = None,
**kwargs,
) -> MultiTrace: ...

@@ -497,6 +501,7 @@ def sample(
mp_ctx=None,
blas_cores: int | None | Literal["auto"] = "auto",
model: Model | None = None,
compile_kwargs: dict | None = None,
**kwargs,
) -> InferenceData | MultiTrace:
r"""Draw samples from the posterior using the given step methods.
@@ -598,6 +603,9 @@ def sample(
See multiprocessing documentation for details.
model : Model (optional if in ``with`` context)
Model to sample from. The model needs to have free random variables.
compile_kwargs: dict, optional
Dictionary with keyword arguments to pass to the functions compiled by the step methods.
Returns
-------
@@ -795,6 +803,7 @@ def joined_blas_limiter():
jitter_max_retries=jitter_max_retries,
tune=tune,
initvals=initvals,
compile_kwargs=compile_kwargs,
**kwargs,
)
else:
@@ -814,6 +823,7 @@ def joined_blas_limiter():
selected_steps=selected_steps,
step_kwargs=kwargs,
initial_point=initial_points[0],
compile_kwargs=compile_kwargs,
)
if isinstance(step, list):
step = CompoundStep(step)
@@ -1390,6 +1400,7 @@ def init_nuts(
jitter_max_retries: int = 10,
tune: int | None = None,
initvals: StartDict | Sequence[StartDict | None] | None = None,
compile_kwargs: dict | None = None,
**kwargs,
) -> tuple[Sequence[PointType], NUTS]:
"""Set up the mass matrix initialization for NUTS.
@@ -1466,6 +1477,9 @@ def init_nuts(
if init == "auto":
init = "jitter+adapt_diag"

if compile_kwargs is None:
compile_kwargs = {}

random_seed_list = _get_seeds_per_chain(random_seed, chains)

_log.info(f"Initializing NUTS using {init}...")
@@ -1477,7 +1491,7 @@ def init_nuts(
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
]

logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True)
logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs)
logp_dlogp_func.trust_input = True
initial_points = _init_jitter(
model,
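For illustration only (not part of this commit): a minimal sketch of how the new keyword could be passed through pm.sample so that it reaches the NUTS initialization shown above. The toy model is hypothetical, and the Mode choice mirrors the test helpers changed later in this commit.

    import pymc as pm
    from pytensor.compile.mode import Mode

    with pm.Model():
        x = pm.Normal("x")  # hypothetical toy model, illustrative only
        # compile_kwargs is forwarded to init_nuts / instantiate_steppers and
        # ends up in the functions compiled by the step methods.
        idata = pm.sample(
            draws=100,
            tune=100,
            chains=1,
            compile_kwargs={"mode": Mode(linker="py", optimizer=None)},
        )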
4 changes: 4 additions & 0 deletions pymc/step_methods/arraystep.py
@@ -182,16 +182,20 @@ def __init__(
logp_dlogp_func=None,
rng: RandomGenerator = None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
**pytensor_kwargs,
):
model = modelcontext(model)

if logp_dlogp_func is None:
if compile_kwargs is None:
compile_kwargs = {}
logp_dlogp_func = model.logp_dlogp_function(
vars,
dtype=dtype,
ravel_inputs=True,
initial_point=initial_point,
**compile_kwargs,
**pytensor_kwargs,
)
logp_dlogp_func.trust_input = True
30 changes: 23 additions & 7 deletions pymc/step_methods/metropolis.py
@@ -162,6 +162,7 @@ def __init__(
mode=None,
rng=None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
blocked: bool = False,
):
"""Create an instance of a Metropolis stepper.
@@ -254,7 +255,7 @@ def __init__(
self.mode = mode

shared = pm.make_shared_replacements(initial_point, vars, model)
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
super().__init__(vars, shared, blocked=blocked, rng=rng)

def reset_tuning(self):
Expand Down Expand Up @@ -432,6 +433,7 @@ def __init__(
model=None,
rng=None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
blocked: bool = True,
):
model = pm.modelcontext(model)
@@ -447,7 +449,9 @@ def __init__(
if not all(v.dtype in pm.discrete_types for v in vars):
raise ValueError("All variables must be Bernoulli for BinaryMetropolis")

super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
if compile_kwargs is None:
compile_kwargs = {}
super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)

def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
logp = args[0]
@@ -554,6 +558,7 @@ def __init__(
model=None,
rng=None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
blocked: bool = True,
):
model = pm.modelcontext(model)
@@ -582,7 +587,10 @@ def __init__(
if not all(v.dtype in pm.discrete_types for v in vars):
raise ValueError("All variables must be binary for BinaryGibbsMetropolis")

super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
if compile_kwargs is None:
compile_kwargs = {}

super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)

def reset_tuning(self):
# There are no tuning parameters in this step method.
@@ -672,6 +680,7 @@ def __init__(
model=None,
rng: RandomGenerator = None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
blocked: bool = True,
):
model = pm.modelcontext(model)
@@ -728,7 +737,9 @@ def __init__(
# that indicates whether a draw was done in a tuning phase.
self.tune = True

super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
if compile_kwargs is None:
compile_kwargs = {}
super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)

def reset_tuning(self):
# There are no tuning parameters in this step method.
@@ -904,6 +915,7 @@ def __init__(
mode=None,
rng=None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
blocked: bool = True,
):
model = pm.modelcontext(model)
@@ -939,7 +951,7 @@ def __init__(
self.mode = mode

shared = pm.make_shared_replacements(initial_point, vars, model)
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
super().__init__(vars, shared, blocked=blocked, rng=rng)

def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
@@ -1073,6 +1085,7 @@ def __init__(
tune_drop_fraction: float = 0.9,
model=None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
mode=None,
rng=None,
blocked: bool = True,
@@ -1122,7 +1135,7 @@ def __init__(
self.mode = mode

shared = pm.make_shared_replacements(initial_point, vars, model)
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
super().__init__(vars, shared, blocked=blocked, rng=rng)

def reset_tuning(self):
@@ -1213,6 +1226,7 @@ def delta_logp(
logp: pt.TensorVariable,
vars: list[pt.TensorVariable],
shared: dict[pt.TensorVariable, pt.sharedvar.TensorSharedVariable],
compile_kwargs: dict | None,
) -> pytensor.compile.Function:
[logp0], inarray0 = join_nonshared_inputs(
point=point, outputs=[logp], inputs=vars, shared_inputs=shared
@@ -1225,6 +1239,8 @@ def delta_logp(
# Replace any potential duplicated RNG nodes
(logp1,) = replace_rng_nodes((logp1,))

f = compile_pymc([inarray1, inarray0], logp1 - logp0)
if compile_kwargs is None:
compile_kwargs = {}
f = compile_pymc([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
f.trust_input = True
return f
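For illustration only (not part of this commit): a sketch of instantiating a Metropolis stepper with compile_kwargs directly, so that the delta_logp function above is compiled with a custom pytensor Mode. The toy model is hypothetical and illustrative only.

    import pymc as pm
    from pytensor.compile.mode import Mode

    with pm.Model():
        mu = pm.Normal("mu")  # hypothetical toy model, illustrative only
        # The keyword reaches delta_logp via the updated Metropolis.__init__.
        step = pm.Metropolis(compile_kwargs={"mode": Mode(linker="py", optimizer=None)})
        idata = pm.sample(draws=50, tune=50, chains=1, step=step)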
5 changes: 4 additions & 1 deletion pymc/step_methods/slicer.py
@@ -86,6 +86,7 @@ def __init__(
iter_limit=np.inf,
rng=None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
blocked: bool = False, # Could be true since tuning is independent across dims?
):
model = modelcontext(model)
@@ -106,7 +107,9 @@ def __init__(
[logp], raveled_inp = join_nonshared_inputs(
point=initial_point, outputs=[model.logp()], inputs=vars, shared_inputs=shared
)
self.logp = compile_pymc([raveled_inp], logp)
if compile_kwargs is None:
compile_kwargs = {}
self.logp = compile_pymc([raveled_inp], logp, **compile_kwargs)
self.logp.trust_input = True

super().__init__(vars, shared, blocked=blocked, rng=rng)
4 changes: 3 additions & 1 deletion tests/helpers.py
@@ -26,6 +26,7 @@
import pytensor

from numpy.testing import assert_array_less
from pytensor.compile.mode import Mode
from pytensor.gradient import verify_grad as at_verify_grad

import pymc as pm
@@ -198,10 +199,11 @@ def continuous_steps(self, step, step_kwargs):
c1 = pm.HalfNormal("c1")
c2 = pm.HalfNormal("c2")

# Test methods can handle initial_point
# Test methods can handle initial_point and compile_kwargs
step_kwargs.setdefault(
"initial_point", {"c1_log__": np.array(0.5), "c2_log__": np.array(0.9)}
)
step_kwargs.setdefault("compile_kwargs", {"mode": Mode(linker="py", optimizer=None)})
with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(
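Aside (not part of this commit): a minimal sketch of what the Mode(linker="py", optimizer=None) object used in the test does at the pytensor level, using a stand-alone pytensor function rather than a PyMC step method.

    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.mode import Mode

    # Compile with the pure-Python linker and no graph optimizations: slower
    # to run but quick to compile, which is why the tests pass it as a
    # compile kwarg.
    x = pt.dscalar("x")
    f = pytensor.function([x], x + 1.0, mode=Mode(linker="py", optimizer=None))
    print(f(1.0))  # 2.0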
12 changes: 6 additions & 6 deletions tests/step_methods/test_metropolis.py
@@ -22,6 +22,8 @@
import pytensor
import pytest

from pytensor.compile.mode import Mode

import pymc as pm

from pymc.step_methods.metropolis import (
@@ -368,18 +370,16 @@ def test_discrete_steps(self, step):
d1 = pm.Bernoulli("d1", p=0.5)
d2 = pm.Bernoulli("d2", p=0.5)

# Test it can take initial_point as a kwarg
# Test it can take initial_point and compile_kwargs as kwargs
step_kwargs = {
"initial_point": {
"d1": np.array(0, dtype="int64"),
"d2": np.array(1, dtype="int64"),
},
"compile_kwargs": {"mode": Mode(linker="py", optimizer=None)},
}
with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
assert [m.rvs_to_values[d1]] == step([d1]).vars
assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(
step([d1, d2]).vars
)
assert [m.rvs_to_values[d1]] == step([d1]).vars
assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(step([d1, d2]).vars)

@pytest.mark.parametrize(
"step, step_kwargs", [(Metropolis, {}), (DEMetropolis, {}), (DEMetropolisZ, {})]
