From bf4b473bef9cbfa69b0b201312fa521f539e825d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 2 Mar 2024 02:42:08 -0500 Subject: [PATCH] pt: print data summary (#3383) Signed-off-by: Jinzhe Zeng --- deepmd/pt/train/training.py | 23 +++++++ deepmd/pt/utils/dataloader.py | 22 +++++++ deepmd/pt/utils/dataset.py | 1 + deepmd/utils/data_system.py | 112 +++++++++++++++++++++++----------- 4 files changed, 124 insertions(+), 34 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 07c8511cfe..97da0ce322 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -53,6 +53,9 @@ from deepmd.pt.utils.stat import ( make_stat_input, ) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) if torch.__version__.startswith("2"): import torch._dynamo @@ -288,6 +291,14 @@ def get_loss(loss_params, start_lr, _ntypes): self.validation_data, self.valid_numb_batch, ) = get_data_loader(training_data, validation_data, training_params) + training_data.print_summary( + "training", to_numpy_array(self.training_dataloader.sampler.weights) + ) + if validation_data is not None: + validation_data.print_summary( + "validation", + to_numpy_array(self.validation_dataloader.sampler.weights), + ) else: ( self.training_dataloader, @@ -317,6 +328,18 @@ def get_loss(loss_params, start_lr, _ntypes): training_params["data_dict"][model_key], ) + training_data[model_key].print_summary( + f"training in {model_key}", + to_numpy_array(self.training_dataloader[model_key].sampler.weights), + ) + if validation_data is not None: + validation_data[model_key].print_summary( + f"validation in {model_key}", + to_numpy_array( + self.validation_dataloader[model_key].sampler.weights + ), + ) + # Learning rate self.warmup_steps = training_params.get("warmup_steps", 0) self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0) diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 65a96418c9..2715bced52 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -39,6 +39,7 @@ DataRequirementItem, ) from deepmd.utils.data_system import ( + print_summary, prob_sys_size_ext, process_sys_probs, ) @@ -91,6 +92,7 @@ def construct_dataset(system): self.total_batch = 0 self.dataloaders = [] + self.batch_sizes = [] for system in self.systems: if dist.is_initialized(): system_sampler = DistributedSampler(system) @@ -110,6 +112,7 @@ def construct_dataset(system): self.batch_size += 1 else: self.batch_size = batch_size + self.batch_sizes.append(self.batch_size) system_dataloader = DataLoader( dataset=system, batch_size=self.batch_size, @@ -155,6 +158,25 @@ def add_data_requirement(self, data_requirement: List[DataRequirementItem]): for system in self.systems: system.add_data_requirement(data_requirement) + def print_summary( + self, + name: str, + prob: List[float], + ): + print_summary( + name, + len(self.systems), + [ss.system for ss in self.systems], + [ss._natoms for ss in self.systems], + self.batch_sizes, + [ + ss._data_system.get_sys_numb_batch(self.batch_sizes[ii]) + for ii, ss in enumerate(self.systems) + ], + prob, + [ss._data_system.pbc for ss in self.systems], + ) + _sentinel = object() QUEUESIZE = 32 diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 40a513acdf..67005b5ed3 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -29,6 +29,7 @@ def __init__( - batch_size: Max frame count in a batch. - type_map: Atom types. """ + self.system = system self._type_map = type_map self._data_system = DeepmdData( sys_path=system, shuffle_test=shuffle, type_map=self._type_map diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 90b600548f..ba1041f113 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -556,40 +556,16 @@ def get_batch_size(self) -> int: """Get the batch size.""" return self.batch_size - def _format_name_length(self, name, width): - if len(name) <= width: - return "{: >{}}".format(name, width) - else: - name = name[-(width - 3) :] - name = "-- " + name - return name - - def print_summary(self, name): - # width 65 - sys_width = 42 - log.info( - f"---Summary of DataSystem: {name:13s}-----------------------------------------------" - ) - log.info("found %d system(s):" % self.nsystems) - log.info( - ("%s " % self._format_name_length("system", sys_width)) - + ("%6s %6s %6s %9s %3s" % ("natoms", "bch_sz", "n_bch", "prob", "pbc")) - ) - for ii in range(self.nsystems): - log.info( - "%s %6d %6d %6d %9.3e %3s" - % ( - self._format_name_length(self.system_dirs[ii], sys_width), - self.natoms[ii], - # TODO batch size * nbatches = number of structures - self.batch_size[ii], - self.nbatches[ii], - self.sys_probs[ii], - "T" if self.data_systems[ii].pbc else "F", - ) - ) - log.info( - "--------------------------------------------------------------------------------------" + def print_summary(self, name: str): + print_summary( + name, + self.nsystems, + self.system_dirs, + self.natoms, + self.batch_size, + self.nbatches, + self.sys_probs, + [ii.pbc for ii in self.data_systems], ) def _make_auto_bs(self, rule): @@ -625,6 +601,74 @@ def _check_type_map_consistency(self, type_map_list): return ret +def _format_name_length(name, width): + if len(name) <= width: + return "{: >{}}".format(name, width) + else: + name = name[-(width - 3) :] + name = "-- " + name + return name + + +def print_summary( + name: str, + nsystems: int, + system_dirs: List[str], + natoms: List[int], + batch_size: List[int], + nbatches: List[int], + sys_probs: List[float], + pbc: List[bool], +): + """Print summary of systems. + + Parameters + ---------- + name : str + The name of the system + nsystems : int + The number of systems + system_dirs : list of str + The directories of the systems + natoms : list of int + The number of atoms + batch_size : list of int + The batch size + nbatches : list of int + The number of batches + sys_probs : list of float + The probabilities + pbc : list of bool + The periodic boundary conditions + """ + # width 65 + sys_width = 42 + log.info( + f"---Summary of DataSystem: {name:13s}-----------------------------------------------" + ) + log.info("found %d system(s):" % nsystems) + log.info( + ("%s " % _format_name_length("system", sys_width)) + + ("%6s %6s %6s %9s %3s" % ("natoms", "bch_sz", "n_bch", "prob", "pbc")) + ) + for ii in range(nsystems): + log.info( + "%s %6d %6d %6d %9.3e %3s" + % ( + _format_name_length(system_dirs[ii], sys_width), + natoms[ii], + # TODO batch size * nbatches = number of structures + batch_size[ii], + nbatches[ii], + sys_probs[ii], + "T" if pbc[ii] else "F", + ) + ) + log.info( + "--------------------------------------------------------------------------------------" + ) + + def process_sys_probs(sys_probs, nbatch): sys_probs = np.array(sys_probs) type_filter = sys_probs >= 0