diff --git a/evofr/infer/MCMC_handler.py b/evofr/infer/MCMC_handler.py index ba7541e..f3d89c9 100644 --- a/evofr/infer/MCMC_handler.py +++ b/evofr/infer/MCMC_handler.py @@ -1,8 +1,7 @@ import pickle from typing import Callable, Dict, Optional, Type -from jax import random -from jax._src.random import KeyArray +from jax import random, Array from numpyro.infer import MCMC, NUTS, Predictive from numpyro.infer.mcmc import MCMCKernel @@ -10,7 +9,7 @@ class MCMCHandler: def __init__( self, - rng_key: Optional[KeyArray] = None, + rng_key: Optional[Array] = None, kernel: Optional[Type[MCMCKernel]] = None, **kernel_kwargs ): diff --git a/evofr/infer/SVI_handler.py b/evofr/infer/SVI_handler.py index 8abfea3..3666ec4 100644 --- a/evofr/infer/SVI_handler.py +++ b/evofr/infer/SVI_handler.py @@ -2,8 +2,7 @@ from typing import Any, Callable, Optional import jax.example_libraries.optimizers as optimizers -from jax import random -from jax._src.random import KeyArray +from jax import random, Array from numpyro.infer import SVI, Predictive, Trace_ELBO from numpyro.infer.autoguide import AutoGuide from numpyro.infer.svi import SVIState @@ -14,7 +13,7 @@ class SVIHandler: def __init__( self, - rng_key: Optional[KeyArray] = None, + rng_key: Optional[Array] = None, loss: Optional[Trace_ELBO] = None, optimizer: Optional[Optimizer] = None, ):