diff --git a/synd/core.py b/synd/core.py index fe6b4ac..7545a3c 100644 --- a/synd/core.py +++ b/synd/core.py @@ -1,8 +1,14 @@ """Functions for interacting with SynD models.""" import pickle +import numpy as np +from numpy.random import default_rng +try: + from packaging.version import parse +except ModuleNotFoundError: + from pkg_resources import parse_version as parse -def load_model(filename: str): +def load_model(filename: str, randomize: bool = True): """ Load a SynD model from a file. @@ -19,4 +25,9 @@ def load_model(filename: str): with open(filename, 'rb') as infile: model = pickle.load(infile) + if randomize: + model.rng = default_rng(seed=None) + + model.numpy_version_greater = parse(np.__version__) >= parse('1.25.0') + return model diff --git a/synd/models/base.py b/synd/models/base.py index 222dad1..45b41c2 100644 --- a/synd/models/base.py +++ b/synd/models/base.py @@ -3,6 +3,10 @@ import logging from rich.logging import RichHandler import pickle +try: + from packaging.version import parse +except ModuleNotFoundError: + from pkg_resources import parse_version as parse logger = logging.getLogger(__name__) @@ -21,6 +25,7 @@ class BaseSynDModel(ABC): def __init__(self): self.logger = logger + self.numpy_version_greater = parse(numpy.__version__) >= parse('1.25.0') def serialize(self): """ diff --git a/synd/models/discrete/markov.py b/synd/models/discrete/markov.py index 537a39d..12b9804 100644 --- a/synd/models/discrete/markov.py +++ b/synd/models/discrete/markov.py @@ -35,6 +35,8 @@ def __init__(self, transition_matrix: ArrayLike, backmapper: Callable[[int], Arr self.rng = np.random.default_rng(seed=seed) + self.numpy_version_greater = Version(numpy.__version__) >= Version('1.25.0') + self.cumulative_probabilities = np.cumsum(self.transition_matrix, axis=1) self.logger.info(f"Discrete Markov model created with {self.n_states} states successfully created") @@ -117,7 +119,10 @@ def generate_trajectory(self, initial_states: ArrayLike, n_steps: int) -> ArrayL trajectories[:, 0] = initial_states - probabilities = self.rng.random(size=(n_trajectories, n_steps - 1)) + if self.numpy_version_greater: + probabilities = np.asarray([generator.random(n_steps -1) for generator in self.rng.spawn(n_trajectories)]) + else: + probabilities = np.asarray([np.random.default_rng(seed=seed).random(n_steps -1) for seed in self.rng.bit_generator._seed_seq.spawn(n_trajectories)]) for istep in range(1, n_steps): current_states = trajectories[:, istep - 1] diff --git a/synd/westpa/propagator.py b/synd/westpa/propagator.py index 39dd085..bd94b06 100644 --- a/synd/westpa/propagator.py +++ b/synd/westpa/propagator.py @@ -174,14 +174,18 @@ def __init__(self, rc: westpa.rc = None): rc_parameters = rc.config.get(['west', 'propagation', 'parameters']) self.topology = md.load(rc_parameters['topology']) + # Determine seed and whether to randomize RNG + self.rng_seed = rc.config.get(['west', 'propagation', 'parameters', 'rng_seed'], None) + self.randomize = rc.config.get(['west', 'propagation', 'parameters', 'randomize'], True) # Default to randomize RNG + if 'synd_model' in rc_parameters.keys(): model_path = rc_parameters['synd_model'] - self.synd_model = synd.core.load_model(model_path) + self.synd_model = synd.core.load_model(model_path, self.randomize) else: pcoord_map_path = rc_parameters['pcoord_map'] with open(pcoord_map_path, 'rb') as inf: pcoord_map = pickle.load(inf) - if type(pcoord_map) is dict: + if isinstance(pcoord_map, dict): backmapper = pcoord_map.get else: backmapper = pcoord_map @@ -196,7 +200,7 @@ def __init__(self, rc: westpa.rc = None): self.synd_model = MarkovGenerator( transition_matrix=self.transition_matrix, backmapper=backmapper, - seed=None + seed=self.rng_seed ) # Our dynamics are propagated in the discrete space, which is recorded only in auxdata. After completing an @@ -232,9 +236,11 @@ def propagate(self, segments): initial_points = np.empty(n_segs, dtype=self.coord_dtype) for iseg, segment in enumerate(segments): - initial_points[iseg] = get_segment_parent_index(segment) + if self.randomize: # Randomize RNG everytime you propagate + self.synd_model.rng = np.random.default_rng(seed=self.rng_seed) + new_trajectories = self.synd_model.generate_trajectory( initial_states=initial_points, n_steps=self.coord_len diff --git a/tests/test_synd.py b/tests/test_synd.py index ae9554a..083a66a 100644 --- a/tests/test_synd.py +++ b/tests/test_synd.py @@ -69,10 +69,12 @@ def test_saving_loading_markov_generator(self): self.synmd_model.save("simple_synmd_model.dat") - loaded_model = load_model("simple_synmd_model.dat") + loaded_model = load_model("simple_synmd_model.dat", randomize=True) assert isinstance(loaded_model, MarkovGenerator) + assert loaded_model.rng.random() != self.synmd_model.rng.random() + os.remove("simple_synmd_model.dat") def test_markov_trajectory_generation(self):