Skip to content

Commit

Permalink
Merge pull request #44 from moldyn/issue43
Browse files Browse the repository at this point in the history
Test for immutability.
  • Loading branch information
braniii authored Nov 10, 2023
2 parents 7b8722b + fb902dd commit 1b3b31d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 12 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0


## [Unreleased]
### Bugfix 🐛:
- Fix mutable properties of `mh.StateTraj` and `mh.LumpedStateTraj`, #43

### Other changes:
- Improved performance of `mh.LumpedStateTraj.microtrajs`


## [1.1.0] - 2023-11-03
Expand Down
29 changes: 17 additions & 12 deletions src/msmhelper/statetraj.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,14 @@ def __init__(self, trajs):
# get number of states
self._states = mh.utils.unique(self._trajs)

# enforce true copy of trajs
if np.array_equal(self._states, np.arange(self.nstates)):
self._trajs = [traj.copy() for traj in self._trajs]
# shift to indices
if np.array_equal(self._states, np.arange(1, self.nstates + 1)):
elif np.array_equal(self._states, np.arange(1, self.nstates + 1)):
self._states = np.arange(1, self.nstates + 1)
self._trajs = [traj - 1 for traj in self._trajs]
elif not np.array_equal(self._states, np.arange(self.nstates)):
else: # not np.array_equal(self._states, np.arange(self.nstates)):
self._trajs, self._states = mh.utils.rename_by_index(
self._trajs,
return_permutation=True,
Expand All @@ -80,7 +83,7 @@ def states(self):
Numpy array holding active set of states.
"""
return self._states
return self._states.copy()

@property
def nstates(self):
Expand Down Expand Up @@ -108,7 +111,7 @@ def ntrajs(self):

@property
def nframes(self):
"""Return cummulative length of all trajectories.
"""Return cumulative length of all trajectories.
Returns
-------
Expand All @@ -131,7 +134,7 @@ def trajs(self):
if np.array_equal(self.states, np.arange(1, self.nstates + 1)):
return [traj + 1 for traj in self._trajs]
if np.array_equal(self.states, np.arange(self.nstates)):
return self._trajs
return self.index_trajs
return mh.shift_data(
self._trajs,
np.arange(self.nstates),
Expand Down Expand Up @@ -160,7 +163,7 @@ def index_trajs(self):
List of ndarrays holding the input data.
"""
return self._trajs
return [traj.copy() for traj in self._trajs]

@property
def index_trajs_flatten(self):
Expand Down Expand Up @@ -338,7 +341,7 @@ def states(self):
Numpy array holding active set of states.
"""
return self._macrostates
return self._macrostates.copy()

@property
def nstates(self):
Expand All @@ -362,8 +365,10 @@ def microstate_trajs(self):
List of ndarrays holding the input data.
"""
if np.array_equal(self.microstates, np.arange(self.nmicrostates)):
return self._trajs
if np.array_equal(self.microstates, np.arange(1, self.nstates + 1)):
return [traj + 1 for traj in self._trajs]
elif np.array_equal(self.microstates, np.arange(self.nmicrostates)):
return self.microstate_index_trajs
return mh.shift_data(
self._trajs,
np.arange(self.nmicrostates),
Expand Down Expand Up @@ -392,7 +397,7 @@ def microstate_index_trajs(self):
List of ndarrays holding the microstate index trajectory.
"""
return self._trajs
return [traj.copy() for traj in self._trajs]

@property
def microstate_index_trajs_flatten(self):
Expand Down Expand Up @@ -448,7 +453,7 @@ def microstates(self):
Numpy array holding active set of states.
"""
return self._states
return self._states.copy()

@property
def nmicrostates(self):
Expand All @@ -472,7 +477,7 @@ def state_assignment(self):
Micro to macrostate assignment vector.
"""
return self._state_assignment
return self._state_assignment.copy()

@property
def _state_assignment_idx(self):
Expand Down
16 changes: 16 additions & 0 deletions test/test_statetraj.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ def test_StateTraj_constructor(statetraj):
lumpedTraj = LumpedStateTraj(statetraj, statetraj)
assert lumpedTraj is StateTraj(lumpedTraj)

# check that immutable
assert traj._trajs[0] is not StateTraj(traj.trajs)._trajs[0]
# check for index trajs
assert (
StateTraj(traj.index_trajs)._trajs[0] is not
StateTraj(traj.index_trajs)._trajs[0]
)


def test_LumpedStateTraj_constructor(macrotraj, statetraj):
"""Test construction of object."""
Expand Down Expand Up @@ -127,6 +135,14 @@ def test_nstates(state_traj, statetraj, macro_traj, macrotraj):
state_traj.nstates = 5


def test_states(state_traj):
"""Test immutability of states property."""
assert state_traj.states is not state_traj.states

with pytest.raises(AttributeError):
state_traj.states = state_traj.states


def test_nframes(state_traj):
"""Test nframes property."""
assert state_traj.nframes == len(state_traj[0])
Expand Down

0 comments on commit 1b3b31d

Please sign in to comment.