Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

breaking: use all sets for training and test #3862

Merged
merged 2 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 24 additions & 37 deletions deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
modifier
Data modifier that has the method `modify_data`
trn_all_set
Use all sets as training dataset. Otherwise, if the number of sets is more than 1, the last set is left for test.
[DEPRECATED] This parameter is deprecated and has no effect; all sets are now used for both training and testing.
sort_atoms : bool
Sort atoms by atom types. Required to be enabled when the data is directly fed to
descriptors except mixed types.
Expand Down Expand Up @@ -109,15 +109,6 @@
# make idx map
self.sort_atoms = sort_atoms
self.idx_map = self._make_idx_map(self.atom_type)
# train dirs
self.test_dir = self.dirs[-1]
if trn_all_set:
self.train_dirs = self.dirs
else:
if len(self.dirs) == 1:
self.train_dirs = self.dirs
else:
self.train_dirs = self.dirs[:-1]
self.data_dict = {}
# add box and coord
self.add("box", 9, must=self.pbc)
Expand Down Expand Up @@ -225,7 +216,7 @@

def check_batch_size(self, batch_size):
"""Check if the system can get a batch of data with `batch_size` frames."""
for ii in self.train_dirs:
for ii in self.dirs:
if self.data_dict["coord"]["high_prec"]:
tmpe = (
(ii / "coord.npy").load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION)
Expand All @@ -240,24 +231,7 @@

def check_test_size(self, test_size):
"""Check if the system can get a test dataset with `test_size` frames."""
if self.data_dict["coord"]["high_prec"]:
tmpe = (
(self.test_dir / "coord.npy")
.load_numpy()
.astype(GLOBAL_ENER_FLOAT_PRECISION)
)
else:
tmpe = (
(self.test_dir / "coord.npy")
.load_numpy()
.astype(GLOBAL_NP_FLOAT_PRECISION)
)
if tmpe.ndim == 1:
tmpe = tmpe.reshape([1, -1])
if tmpe.shape[0] < test_size:
return self.test_dir, tmpe.shape[0]
else:
return None
return self.check_batch_size(test_size)

def get_item_torch(self, index: int) -> dict:
"""Get a single frame data . The frame is picked from the data system by index. The index is coded across all the sets.
Expand Down Expand Up @@ -287,7 +261,7 @@
else:
set_size = 0
if self.iterator + batch_size > set_size:
self._load_batch_set(self.train_dirs[self.set_count % self.get_numb_set()])
self._load_batch_set(self.dirs[self.set_count % self.get_numb_set()])
self.set_count += 1
set_size = self.batch_set["coord"].shape[0]
iterator_1 = self.iterator + batch_size
Expand All @@ -307,7 +281,7 @@
Size of the test data set. If `ntests` is -1, all test data will be get.
"""
if not hasattr(self, "test_set"):
self._load_test_set(self.test_dir, self.shuffle_test)
self._load_test_set(self.shuffle_test)
if ntests == -1:
idx = None
else:
Expand Down Expand Up @@ -340,11 +314,11 @@

def get_numb_set(self) -> int:
"""Get number of training sets."""
return len(self.train_dirs)
return len(self.dirs)

def get_numb_batch(self, batch_size: int, set_idx: int) -> int:
"""Get the number of batches in a set."""
data = self._load_set(self.train_dirs[set_idx])
data = self._load_set(self.dirs[set_idx])
ret = data["coord"].shape[0] // batch_size
if ret == 0:
ret = 1
Expand All @@ -353,7 +327,7 @@
def get_sys_numb_batch(self, batch_size: int) -> int:
"""Get the number of batches in the data system."""
ret = 0
for ii in range(len(self.train_dirs)):
for ii in range(len(self.dirs)):
ret += self.get_numb_batch(batch_size, ii)
return ret

Expand Down Expand Up @@ -388,7 +362,7 @@
info = self.data_dict[key]
ndof = info["ndof"]
eners = []
for ii in self.train_dirs:
for ii in self.dirs:

Check warning on line 365 in deepmd/utils/data.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/data.py#L365

Added line #L365 was not covered by tests
data = self._load_set(ii)
ei = data[key].reshape([-1, ndof])
eners.append(ei)
Expand Down Expand Up @@ -441,8 +415,21 @@
def reset_get_batch(self):
self.iterator = 0

def _load_test_set(self, set_name: DPPath, shuffle_test):
self.test_set = self._load_set(set_name)
def _load_test_set(self, shuffle_test: bool):
test_sets = []
for ii in self.dirs:
test_set = self._load_set(ii)
test_sets.append(test_set)
# merge test sets
self.test_set = {}
assert len(test_sets) > 0
for kk in test_sets[0]:
if "find_" in kk:
self.test_set[kk] = test_sets[0][kk]
else:
self.test_set[kk] = np.concatenate(
[test_set[kk] for test_set in test_sets], axis=0
)
if shuffle_test:
self.test_set, _ = self._shuffle_data(self.test_set)

Expand Down
33 changes: 25 additions & 8 deletions source/tests/tf/test_deepmd_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def setUp(self):
path = os.path.join(self.data_name, "set.bar", "test_frame.npy")
self.test_frame_bar = rng.random([self.nframes, 5])
np.save(path, self.test_frame_bar)
path = os.path.join(self.data_name, "set.tar", "test_frame.npy")
self.test_frame_tar = rng.random([2, 5])
np.save(path, self.test_frame_tar)
# t n
self.test_null = np.zeros([self.nframes, 2 * self.natoms])
# tensor shape
Expand All @@ -162,8 +165,9 @@ def test_init(self):
self.assertEqual(dd.idx_map[0], 1)
self.assertEqual(dd.idx_map[1], 0)
self.assertEqual(dd.type_map, ["foo", "bar"])
self.assertEqual(dd.test_dir, "test_data/set.tar")
self.assertEqual(dd.train_dirs, ["test_data/set.bar", "test_data/set.foo"])
self.assertEqual(
dd.dirs, ["test_data/set.bar", "test_data/set.foo", "test_data/set.tar"]
)

def test_init_type_map(self):
dd = DeepmdData(self.data_name, type_map=["bar", "foo", "tar"])
Expand All @@ -182,7 +186,7 @@ def test_load_set(self):
)
data = dd._load_set(os.path.join(self.data_name, "set.foo"))
nframes = data["coord"].shape[0]
self.assertEqual(dd.get_numb_set(), 2)
self.assertEqual(dd.get_numb_set(), 3)
self.assertEqual(dd.get_type_map(), ["foo", "bar"])
self.assertEqual(dd.get_natoms(), 2)
self.assertEqual(list(dd.get_natoms_vec(3)), [2, 2, 1, 1, 0])
Expand Down Expand Up @@ -257,7 +261,10 @@ def test_avg(self):
dd = DeepmdData(self.data_name).add("test_frame", 5, atomic=False, must=True)
favg = dd.avg("test_frame")
fcmp = np.average(
np.concatenate((self.test_frame, self.test_frame_bar), axis=0), axis=0
np.concatenate(
(self.test_frame, self.test_frame_bar, self.test_frame_tar), axis=0
),
axis=0,
)
np.testing.assert_almost_equal(favg, fcmp, places)

Expand All @@ -266,13 +273,17 @@ def test_check_batch_size(self):
ret = dd.check_batch_size(10)
self.assertEqual(ret, (os.path.join(self.data_name, "set.bar"), 5))
ret = dd.check_batch_size(5)
self.assertEqual(ret, (os.path.join(self.data_name, "set.tar"), 2))
ret = dd.check_batch_size(1)
self.assertEqual(ret, None)

def test_check_test_size(self):
dd = DeepmdData(self.data_name)
ret = dd.check_test_size(10)
self.assertEqual(ret, (os.path.join(self.data_name, "set.bar"), 5))
ret = dd.check_test_size(5)
self.assertEqual(ret, (os.path.join(self.data_name, "set.tar"), 2))
ret = dd.check_test_size(2)
ret = dd.check_test_size(1)
self.assertEqual(ret, None)

def test_get_batch(self):
Expand All @@ -284,6 +295,10 @@ def test_get_batch(self):
data = dd.get_batch(5)
self._comp_np_mat2(np.sort(data["coord"], axis=0), np.sort(self.coord, axis=0))
data = dd.get_batch(5)
self._comp_np_mat2(
np.sort(data["coord"], axis=0), np.sort(self.coord_tar, axis=0)
)
data = dd.get_batch(5)
self._comp_np_mat2(
np.sort(data["coord"], axis=0), np.sort(self.coord_bar, axis=0)
)
Expand All @@ -293,8 +308,11 @@ def test_get_batch(self):
def test_get_test(self):
dd = DeepmdData(self.data_name)
data = dd.get_test()
expected_coord = np.concatenate(
(self.coord_bar, self.coord, self.coord_tar), axis=0
)
self._comp_np_mat2(
np.sort(data["coord"], axis=0), np.sort(self.coord_tar, axis=0)
np.sort(data["coord"], axis=0), np.sort(expected_coord, axis=0)
)

def test_get_nbatch(self):
Expand Down Expand Up @@ -368,8 +386,7 @@ def test_init(self):
dd = DeepmdData(self.data_name)
self.assertEqual(dd.idx_map[0], 0)
self.assertEqual(dd.type_map, ["X"])
self.assertEqual(dd.test_dir, self.data_name + "#/set.000")
self.assertEqual(dd.train_dirs, [self.data_name + "#/set.000"])
self.assertEqual(dd.dirs[0], self.data_name + "#/set.000")

def test_get_batch(self):
dd = DeepmdData(self.data_name)
Expand Down
Loading