diff --git a/CHANGELOG.md b/CHANGELOG.md index aa682cf..97282e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Bugfix 🐛: +- Fix mutable properties of `mh.StateTraj` and `mh.LumpedStateTraj`, #43 + +### Other changes: +- Improved performance of `mh.LumpedStateTraj.microtrajs` ## [1.1.0] - 2023-11-03 diff --git a/src/msmhelper/statetraj.py b/src/msmhelper/statetraj.py index dfc83ad..d820f03 100644 --- a/src/msmhelper/statetraj.py +++ b/src/msmhelper/statetraj.py @@ -60,11 +60,14 @@ def __init__(self, trajs): # get number of states self._states = mh.utils.unique(self._trajs) + # enforce true copy of trajs + if np.array_equal(self._states, np.arange(self.nstates)): + self._trajs = [traj.copy() for traj in self._trajs] # shift to indices - if np.array_equal(self._states, np.arange(1, self.nstates + 1)): + elif np.array_equal(self._states, np.arange(1, self.nstates + 1)): self._states = np.arange(1, self.nstates + 1) self._trajs = [traj - 1 for traj in self._trajs] - elif not np.array_equal(self._states, np.arange(self.nstates)): + else: # not np.array_equal(self._states, np.arange(self.nstates)): self._trajs, self._states = mh.utils.rename_by_index( self._trajs, return_permutation=True, @@ -80,7 +83,7 @@ def states(self): Numpy array holding active set of states. """ - return self._states + return self._states.copy() @property def nstates(self): @@ -108,7 +111,7 @@ def ntrajs(self): @property def nframes(self): - """Return cummulative length of all trajectories. + """Return cumulative length of all trajectories. Returns ------- @@ -131,7 +134,7 @@ def trajs(self): if np.array_equal(self.states, np.arange(1, self.nstates + 1)): return [traj + 1 for traj in self._trajs] if np.array_equal(self.states, np.arange(self.nstates)): - return self._trajs + return self.index_trajs return mh.shift_data( self._trajs, np.arange(self.nstates), @@ -160,7 +163,7 @@ def index_trajs(self): List of ndarrays holding the input data. """ - return self._trajs + return [traj.copy() for traj in self._trajs] @property def index_trajs_flatten(self): @@ -338,7 +341,7 @@ def states(self): Numpy array holding active set of states. """ - return self._macrostates + return self._macrostates.copy() @property def nstates(self): @@ -362,8 +365,10 @@ def microstate_trajs(self): List of ndarrays holding the input data. """ - if np.array_equal(self.microstates, np.arange(self.nmicrostates)): - return self._trajs + if np.array_equal(self.microstates, np.arange(1, self.nstates + 1)): + return [traj + 1 for traj in self._trajs] + elif np.array_equal(self.microstates, np.arange(self.nmicrostates)): + return self.microstate_index_trajs return mh.shift_data( self._trajs, np.arange(self.nmicrostates), @@ -392,7 +397,7 @@ def microstate_index_trajs(self): List of ndarrays holding the microstate index trajectory. """ - return self._trajs + return [traj.copy() for traj in self._trajs] @property def microstate_index_trajs_flatten(self): @@ -448,7 +453,7 @@ def microstates(self): Numpy array holding active set of states. """ - return self._states + return self._states.copy() @property def nmicrostates(self): @@ -472,7 +477,7 @@ def state_assignment(self): Micro to macrostate assignment vector. """ - return self._state_assignment + return self._state_assignment.copy() @property def _state_assignment_idx(self): diff --git a/test/test_statetraj.py b/test/test_statetraj.py index 09202b7..172b1bf 100644 --- a/test/test_statetraj.py +++ b/test/test_statetraj.py @@ -96,6 +96,14 @@ def test_StateTraj_constructor(statetraj): lumpedTraj = LumpedStateTraj(statetraj, statetraj) assert lumpedTraj is StateTraj(lumpedTraj) + # check that immutable + assert traj._trajs[0] is not StateTraj(traj.trajs)._trajs[0] + # check for index trajs + assert ( + StateTraj(traj.index_trajs)._trajs[0] is not + StateTraj(traj.index_trajs)._trajs[0] + ) + def test_LumpedStateTraj_constructor(macrotraj, statetraj): """Test construction of object.""" @@ -127,6 +135,14 @@ def test_nstates(state_traj, statetraj, macro_traj, macrotraj): state_traj.nstates = 5 +def test_states(state_traj): + """Test immutability of states property.""" + assert state_traj.states is not state_traj.states + + with pytest.raises(AttributeError): + state_traj.states = state_traj.states + + def test_nframes(state_traj): """Test nframes property.""" assert state_traj.nframes == len(state_traj[0])