From e5b50e414dedef052238b978552220e35e3ba1ac Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 1 Mar 2024 23:42:22 -0500 Subject: [PATCH] pt: expand systems before training Signed-off-by: Jinzhe Zeng --- deepmd/pt/entrypoints/main.py | 6 ++++ deepmd/utils/data_system.py | 55 ++++++++++++++++++++++++----------- deepmd/utils/path.py | 2 ++ 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 736e8dde09..0e5767cb4e 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -59,6 +59,9 @@ from deepmd.utils.compat import ( update_deepmd_input, ) +from deepmd.utils.data_system import ( + process_systems, +) from deepmd.utils.path import ( DPPath, ) @@ -108,6 +111,9 @@ def prepare_trainer_input_single( validation_dataset_params["systems"] if validation_dataset_params else None ) training_systems = training_dataset_params["systems"] + training_systems = process_systems(training_systems) + if validation_systems is not None: + validation_systems = process_systems(validation_systems) # stat files stat_file_path_single = data_dict_single.get("stat_file", None) diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 90b600548f..347b0a14f6 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -10,6 +10,7 @@ Dict, List, Optional, + Union, ) import numpy as np @@ -667,30 +668,22 @@ def prob_sys_size_ext(keywords, nsystems, nbatch): return sys_probs -def get_data( - jdata: Dict[str, Any], rcut, type_map, modifier, multi_task_mode=False -) -> DeepmdDataSystem: - """Get the data system. +def process_systems(systems: Union[str, List[str]]) -> List[str]: + """Process the user-input systems. + + If it is a single directory, search for all the systems in the directory. + Check if the systems are valid. Parameters ---------- - jdata - The json data - rcut - The cut-off radius, not used - type_map - The type map - modifier - The data modifier - multi_task_mode - If in multi task mode + systems : str or list of str + The user-input systems Returns ------- - DeepmdDataSystem - The data system + list of str + The valid systems """ - systems = j_must_have(jdata, "systems") if isinstance(systems, str): systems = expand_sys_str(systems) elif isinstance(systems, list): @@ -712,6 +705,34 @@ def get_data( msg = f"dir {ii} is not a valid data system dir" log.fatal(msg) raise OSError(msg, help_msg) + return systems + + +def get_data( + jdata: Dict[str, Any], rcut, type_map, modifier, multi_task_mode=False +) -> DeepmdDataSystem: + """Get the data system. + + Parameters + ---------- + jdata + The json data + rcut + The cut-off radius, not used + type_map + The type map + modifier + The data modifier + multi_task_mode + If in multi task mode + + Returns + ------- + DeepmdDataSystem + The data system + """ + systems = j_must_have(jdata, "systems") + systems = process_systems(systems) batch_size = j_must_have(jdata, "batch_size") sys_probs = jdata.get("sys_probs", None) diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index 79361b6c23..5887e91850 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -414,6 +414,8 @@ def is_file(self) -> bool: def is_dir(self) -> bool: """Check if self is directory.""" + if self._name == "/": + return True if self._name not in self._keys: return False return isinstance(self.root[self._name], h5py.Group)