Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing the Random Number Generator Bug from pickled objects #8

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
13 changes: 12 additions & 1 deletion synd/core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
"""Functions for interacting with SynD models."""
import pickle
import numpy as np
from numpy.random import default_rng
try:
from packaging.version import parse
except ModuleNotFoundError:
from pkg_resources import parse_version as parse


def load_model(filename: str):
def load_model(filename: str, randomize: bool = True):
"""
Load a SynD model from a file.

Expand All @@ -19,4 +25,9 @@ def load_model(filename: str):
with open(filename, 'rb') as infile:
model = pickle.load(infile)

if randomize:
model.rng = default_rng(seed=None)

model.numpy_version_greater = parse(np.__version__) >= parse('1.25.0')

return model
5 changes: 5 additions & 0 deletions synd/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import logging
from rich.logging import RichHandler
import pickle
try:
from packaging.version import parse
except ModuleNotFoundError:
from pkg_resources import parse_version as parse

logger = logging.getLogger(__name__)

Expand All @@ -21,6 +25,7 @@ class BaseSynDModel(ABC):
def __init__(self):

self.logger = logger
self.numpy_version_greater = parse(numpy.__version__) >= parse('1.25.0')

def serialize(self):
"""
Expand Down
7 changes: 6 additions & 1 deletion synd/models/discrete/markov.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def __init__(self, transition_matrix: ArrayLike, backmapper: Callable[[int], Arr

self.rng = np.random.default_rng(seed=seed)

self.numpy_version_greater = Version(numpy.__version__) >= Version('1.25.0')

self.cumulative_probabilities = np.cumsum(self.transition_matrix, axis=1)

self.logger.info(f"Discrete Markov model created with {self.n_states} states successfully created")
Expand Down Expand Up @@ -117,7 +119,10 @@ def generate_trajectory(self, initial_states: ArrayLike, n_steps: int) -> ArrayL

trajectories[:, 0] = initial_states

probabilities = self.rng.random(size=(n_trajectories, n_steps - 1))
if self.numpy_version_greater:
probabilities = np.asarray([generator.random(n_steps -1) for generator in self.rng.spawn(n_trajectories)])
else:
probabilities = np.asarray([np.random.default_rng(seed=seed).random(n_steps -1) for seed in self.rng.bit_generator._seed_seq.spawn(n_trajectories)])

for istep in range(1, n_steps):
current_states = trajectories[:, istep - 1]
Expand Down
14 changes: 10 additions & 4 deletions synd/westpa/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,18 @@ def __init__(self, rc: westpa.rc = None):
rc_parameters = rc.config.get(['west', 'propagation', 'parameters'])
self.topology = md.load(rc_parameters['topology'])

# Determine seed and whether to randomize RNG
self.rng_seed = rc.config.get(['west', 'propagation', 'parameters', 'rng_seed'], None)
self.randomize = rc.config.get(['west', 'propagation', 'parameters', 'randomize'], True) # Default to randomize RNG

if 'synd_model' in rc_parameters.keys():
model_path = rc_parameters['synd_model']
self.synd_model = synd.core.load_model(model_path)
self.synd_model = synd.core.load_model(model_path, self.randomize)
else:
pcoord_map_path = rc_parameters['pcoord_map']
with open(pcoord_map_path, 'rb') as inf:
pcoord_map = pickle.load(inf)
if type(pcoord_map) is dict:
if isinstance(pcoord_map, dict):
backmapper = pcoord_map.get
else:
backmapper = pcoord_map
Expand All @@ -196,7 +200,7 @@ def __init__(self, rc: westpa.rc = None):
self.synd_model = MarkovGenerator(
transition_matrix=self.transition_matrix,
backmapper=backmapper,
seed=None
seed=self.rng_seed
)

# Our dynamics are propagated in the discrete space, which is recorded only in auxdata. After completing an
Expand Down Expand Up @@ -232,9 +236,11 @@ def propagate(self, segments):
initial_points = np.empty(n_segs, dtype=self.coord_dtype)

for iseg, segment in enumerate(segments):

initial_points[iseg] = get_segment_parent_index(segment)

if self.randomize: # Randomize RNG everytime you propagate
self.synd_model.rng = np.random.default_rng(seed=self.rng_seed)

new_trajectories = self.synd_model.generate_trajectory(
initial_states=initial_points,
n_steps=self.coord_len
Expand Down
4 changes: 3 additions & 1 deletion tests/test_synd.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ def test_saving_loading_markov_generator(self):

self.synmd_model.save("simple_synmd_model.dat")

loaded_model = load_model("simple_synmd_model.dat")
loaded_model = load_model("simple_synmd_model.dat", randomize=True)

assert isinstance(loaded_model, MarkovGenerator)

assert loaded_model.rng.random() != self.synmd_model.rng.random()

os.remove("simple_synmd_model.dat")

def test_markov_trajectory_generation(self):
Expand Down