From 389d403ef97ba374fb4af5f4256c30c2299cdc95 Mon Sep 17 00:00:00 2001 From: Lysithea <52808607+CaRoLZhangxy@users.noreply.github.com> Date: Tue, 31 Oct 2023 14:55:13 +0800 Subject: [PATCH] merge prob_sys_size with prob_sys_size;0:nsys:1.0 (#2963) to be consistent with Pytorch version --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/utils/data_system.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 69a6cbe112..09dcac2d8d 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -195,8 +195,7 @@ def __init__( assert isinstance(self.test_size, (list, np.ndarray)) assert len(self.test_size) == self.nsystems - # prob of batch, init pick idx - self.prob_nbatches = [float(i) for i in self.nbatches] / np.sum(self.nbatches) + # init pick idx self.pick_idx = 0 # derive system probabilities @@ -350,11 +349,13 @@ def set_sys_probs(self, sys_probs=None, auto_prob_style: str = "prob_sys_size"): if auto_prob_style == "prob_uniform": prob_v = 1.0 / float(self.nsystems) probs = [prob_v for ii in range(self.nsystems)] - elif auto_prob_style == "prob_sys_size": - probs = self.prob_nbatches - elif auto_prob_style[:14] == "prob_sys_size;": + elif auto_prob_style[:13] == "prob_sys_size": + if auto_prob_style == "prob_sys_size": + prob_style = f"prob_sys_size;0:{self.get_nsystems()}:1.0" + else: + prob_style = auto_prob_style probs = prob_sys_size_ext( - auto_prob_style, self.get_nsystems(), self.nbatches + prob_style, self.get_nsystems(), self.nbatches ) else: raise RuntimeError("Unknown auto prob style: " + auto_prob_style)