From bd232d2b20cb613f6e4374182648ba9e168522fb Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 19 Nov 2024 14:19:12 +0100
Subject: [PATCH] Allow passing compile_kwargs to step inner functions

---
 pymc/sampling/mcmc.py                 | 16 +++++++++++++-
 pymc/step_methods/arraystep.py        |  4 ++++
 pymc/step_methods/metropolis.py       | 30 ++++++++++++++++++++-------
 pymc/step_methods/slicer.py           |  5 ++++-
 tests/helpers.py                      |  4 +++-
 tests/step_methods/test_metropolis.py | 12 +++++------
 6 files changed, 55 insertions(+), 16 deletions(-)

diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index defa5b5383..11cff18bb6 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -104,6 +104,7 @@ def instantiate_steppers(
     *,
     step_kwargs: dict[str, dict] | None = None,
     initial_point: PointType | None = None,
+    compile_kwargs: dict | None = None,
 ) -> Step | list[Step]:
     """Instantiate steppers assigned to the model variables.
 
@@ -146,6 +147,7 @@ def instantiate_steppers(
                 vars=vars,
                 model=model,
                 initial_point=initial_point,
+                compile_kwargs=compile_kwargs,
                 **kwargs,
             )
             steps.append(step)
@@ -434,6 +436,7 @@ def sample(
     callback=None,
     mp_ctx=None,
     blas_cores: int | None | Literal["auto"] = "auto",
+    compile_kwargs: dict | None = None,
     **kwargs,
 ) -> InferenceData: ...
 
@@ -466,6 +469,7 @@ def sample(
     mp_ctx=None,
     model: Model | None = None,
     blas_cores: int | None | Literal["auto"] = "auto",
+    compile_kwargs: dict | None = None,
     **kwargs,
 ) -> MultiTrace: ...
 
@@ -497,6 +501,7 @@ def sample(
     mp_ctx=None,
     blas_cores: int | None | Literal["auto"] = "auto",
     model: Model | None = None,
+    compile_kwargs: dict | None = None,
     **kwargs,
 ) -> InferenceData | MultiTrace:
     r"""Draw samples from the posterior using the given step methods.
@@ -598,6 +603,9 @@ def sample(
         See multiprocessing documentation for details.
     model : Model (optional if in ``with`` context)
         Model to sample from. The model needs to have free random variables.
+    compile_kwargs : dict, optional
+        Dictionary with keyword arguments to pass to the functions compiled by the step methods.
+
 
     Returns
     -------
@@ -795,6 +803,7 @@ def joined_blas_limiter():
            jitter_max_retries=jitter_max_retries,
            tune=tune,
            initvals=initvals,
+           compile_kwargs=compile_kwargs,
            **kwargs,
        )
    else:
@@ -814,6 +823,7 @@ def joined_blas_limiter():
            selected_steps=selected_steps,
            step_kwargs=kwargs,
            initial_point=initial_points[0],
+           compile_kwargs=compile_kwargs,
        )
        if isinstance(step, list):
            step = CompoundStep(step)
@@ -1390,6 +1400,7 @@ def init_nuts(
     jitter_max_retries: int = 10,
     tune: int | None = None,
     initvals: StartDict | Sequence[StartDict | None] | None = None,
+    compile_kwargs: dict | None = None,
     **kwargs,
 ) -> tuple[Sequence[PointType], NUTS]:
     """Set up the mass matrix initialization for NUTS.
@@ -1466,6 +1477,9 @@ def init_nuts(
     if init == "auto":
         init = "jitter+adapt_diag"
 
+    if compile_kwargs is None:
+        compile_kwargs = {}
+
     random_seed_list = _get_seeds_per_chain(random_seed, chains)
 
     _log.info(f"Initializing NUTS using {init}...")
@@ -1477,7 +1491,7 @@ def init_nuts(
         pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
     ]
 
-    logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True)
+    logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True, **compile_kwargs)
     logp_dlogp_func.trust_input = True
     initial_points = _init_jitter(
         model,
diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py
index 7ddfb65f06..d15b14499c 100644
--- a/pymc/step_methods/arraystep.py
+++ b/pymc/step_methods/arraystep.py
@@ -182,16 +182,20 @@ def __init__(
         logp_dlogp_func=None,
         rng: RandomGenerator = None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         **pytensor_kwargs,
     ):
         model = modelcontext(model)
 
         if logp_dlogp_func is None:
+            if compile_kwargs is None:
+                compile_kwargs = {}
             logp_dlogp_func = model.logp_dlogp_function(
                 vars,
                 dtype=dtype,
                 ravel_inputs=True,
                 initial_point=initial_point,
+                **compile_kwargs,
                 **pytensor_kwargs,
             )
             logp_dlogp_func.trust_input = True
diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index 8aed4bfd67..b6d82243a1 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -162,6 +162,7 @@ def __init__(
         mode=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = False,
     ):
         """Create an instance of a Metropolis stepper.
@@ -254,7 +255,7 @@ def __init__(
         self.mode = mode
 
         shared = pm.make_shared_replacements(initial_point, vars, model)
-        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, blocked=blocked, rng=rng)
 
     def reset_tuning(self):
@@ -432,6 +433,7 @@ def __init__(
         model=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -447,7 +449,9 @@ def __init__(
         if not all(v.dtype in pm.discrete_types for v in vars):
             raise ValueError("All variables must be Bernoulli for BinaryMetropolis")
 
-        super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)
 
     def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         logp = args[0]
@@ -554,6 +558,7 @@ def __init__(
         model=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -582,7 +587,10 @@ def __init__(
         if not all(v.dtype in pm.discrete_types for v in vars):
             raise ValueError("All variables must be binary for BinaryGibbsMetropolis")
 
-        super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
+        if compile_kwargs is None:
+            compile_kwargs = {}
+
+        super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)
 
     def reset_tuning(self):
         # There are no tuning parameters in this step method.
@@ -672,6 +680,7 @@ def __init__(
         model=None,
         rng: RandomGenerator = None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -728,7 +737,9 @@ def __init__(
         # that indicates whether a draw was done in a tuning phase.
         self.tune = True
 
-        super().__init__(vars, [model.compile_logp()], blocked=blocked, rng=rng)
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng)
 
     def reset_tuning(self):
         # There are no tuning parameters in this step method.
@@ -904,6 +915,7 @@ def __init__(
         mode=None,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = True,
     ):
         model = pm.modelcontext(model)
@@ -939,7 +951,7 @@ def __init__(
         self.mode = mode
 
         shared = pm.make_shared_replacements(initial_point, vars, model)
-        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, blocked=blocked, rng=rng)
 
     def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
@@ -1073,6 +1085,7 @@ def __init__(
         tune_drop_fraction: float = 0.9,
         model=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         mode=None,
         rng=None,
         blocked: bool = True,
@@ -1122,7 +1135,7 @@ def __init__(
         self.mode = mode
 
         shared = pm.make_shared_replacements(initial_point, vars, model)
-        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared)
+        self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs)
         super().__init__(vars, shared, blocked=blocked, rng=rng)
 
     def reset_tuning(self):
@@ -1213,6 +1226,7 @@ def delta_logp(
     logp: pt.TensorVariable,
     vars: list[pt.TensorVariable],
     shared: dict[pt.TensorVariable, pt.sharedvar.TensorSharedVariable],
+    compile_kwargs: dict | None,
 ) -> pytensor.compile.Function:
     [logp0], inarray0 = join_nonshared_inputs(
         point=point, outputs=[logp], inputs=vars, shared_inputs=shared
     )
@@ -1225,6 +1239,8 @@ def delta_logp(
 
     # Replace any potential duplicated RNG nodes
     (logp1,) = replace_rng_nodes((logp1,))
 
-    f = compile_pymc([inarray1, inarray0], logp1 - logp0)
+    if compile_kwargs is None:
+        compile_kwargs = {}
+    f = compile_pymc([inarray1, inarray0], logp1 - logp0, **compile_kwargs)
     f.trust_input = True
     return f
diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py
index 57b25e9512..b84674390d 100644
--- a/pymc/step_methods/slicer.py
+++ b/pymc/step_methods/slicer.py
@@ -86,6 +86,7 @@ def __init__(
         iter_limit=np.inf,
         rng=None,
         initial_point: PointType | None = None,
+        compile_kwargs: dict | None = None,
         blocked: bool = False,  # Could be true since tuning is independent across dims?
     ):
         model = modelcontext(model)
@@ -106,7 +107,9 @@ def __init__(
         [logp], raveled_inp = join_nonshared_inputs(
             point=initial_point, outputs=[model.logp()], inputs=vars, shared_inputs=shared
         )
-        self.logp = compile_pymc([raveled_inp], logp)
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        self.logp = compile_pymc([raveled_inp], logp, **compile_kwargs)
         self.logp.trust_input = True
 
         super().__init__(vars, shared, blocked=blocked, rng=rng)
diff --git a/tests/helpers.py b/tests/helpers.py
index ba481d6763..e4b6248930 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -26,6 +26,7 @@
 import pytensor
 
 from numpy.testing import assert_array_less
+from pytensor.compile.mode import Mode
 from pytensor.gradient import verify_grad as at_verify_grad
 
 import pymc as pm
@@ -198,10 +199,11 @@ def continuous_steps(self, step, step_kwargs):
             c1 = pm.HalfNormal("c1")
             c2 = pm.HalfNormal("c2")
 
-            # Test methods can handle initial_point
+            # Test methods can handle initial_point and compile_kwargs
             step_kwargs.setdefault(
                 "initial_point", {"c1_log__": np.array(0.5), "c2_log__": np.array(0.9)}
             )
+            step_kwargs.setdefault("compile_kwargs", {"mode": Mode(linker="py", optimizer=None)})
             with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
                 assert [m.rvs_to_values[c1]] == step([c1], **step_kwargs).vars
                 assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(
diff --git a/tests/step_methods/test_metropolis.py b/tests/step_methods/test_metropolis.py
index 259b6e0546..a01e75506b 100644
--- a/tests/step_methods/test_metropolis.py
+++ b/tests/step_methods/test_metropolis.py
@@ -22,6 +22,8 @@
 import pytensor
 import pytest
 
+from pytensor.compile.mode import Mode
+
 import pymc as pm
 
 from pymc.step_methods.metropolis import (
@@ -368,18 +370,16 @@ def test_discrete_steps(self, step):
             d1 = pm.Bernoulli("d1", p=0.5)
             d2 = pm.Bernoulli("d2", p=0.5)
 
-            # Test it can take initial_point as a kwarg
+            # Test it can take initial_point and compile_kwargs as kwargs
             step_kwargs = {
                 "initial_point": {
                     "d1": np.array(0, dtype="int64"),
                     "d2": np.array(1, dtype="int64"),
                 },
+                "compile_kwargs": {"mode": Mode(linker="py", optimizer=None)},
             }
-            with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
-                assert [m.rvs_to_values[d1]] == step([d1]).vars
-                assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(
-                    step([d1, d2]).vars
-                )
+            assert [m.rvs_to_values[d1]] == step([d1], **step_kwargs).vars
+            assert {m.rvs_to_values[d1], m.rvs_to_values[d2]} == set(step([d1, d2], **step_kwargs).vars)
 
     @pytest.mark.parametrize(
         "step, step_kwargs", [(Metropolis, {}), (DEMetropolis, {}), (DEMetropolisZ, {})]
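
Usage sketch (illustrative, not part of the patch): compile_kwargs takes the
same keyword arguments as the underlying compilation calls (compile_pymc,
model.compile_logp, model.logp_dlogp_function), for example a custom PyTensor
Mode as used in the tests above. The model below is hypothetical.

    import pymc as pm
    from pytensor.compile.mode import Mode

    compile_kwargs = {"mode": Mode(linker="py", optimizer=None)}

    with pm.Model():
        x = pm.Normal("x")
        # Forwarded by Metropolis to delta_logp, which compiles the
        # logp-difference function with the given mode.
        step = pm.Metropolis(compile_kwargs=compile_kwargs)
        # pm.sample also accepts compile_kwargs and forwards it to init_nuts
        # and to any step methods it instantiates on its own.
        idata = pm.sample(draws=100, tune=100, chains=1, step=step, compile_kwargs=compile_kwargs)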