pt: print data summary (#3383)
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Mar 2, 2024
1 parent f4abe12 commit bf4b473
Showing 4 changed files with 124 additions and 34 deletions.
23 changes: 23 additions & 0 deletions deepmd/pt/train/training.py
@@ -53,6 +53,9 @@
from deepmd.pt.utils.stat import (
make_stat_input,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)

if torch.__version__.startswith("2"):
import torch._dynamo
@@ -288,6 +291,14 @@ def get_loss(loss_params, start_lr, _ntypes):
self.validation_data,
self.valid_numb_batch,
) = get_data_loader(training_data, validation_data, training_params)
training_data.print_summary(
"training", to_numpy_array(self.training_dataloader.sampler.weights)
)
if validation_data is not None:
validation_data.print_summary(
"validation",
to_numpy_array(self.validation_dataloader.sampler.weights),
)
else:
(
self.training_dataloader,
@@ -317,6 +328,18 @@ def get_loss(loss_params, start_lr, _ntypes):
training_params["data_dict"][model_key],
)

training_data[model_key].print_summary(
f"training in {model_key}",
to_numpy_array(self.training_dataloader[model_key].sampler.weights),
)
if validation_data is not None:
validation_data[model_key].print_summary(
f"validation in {model_key}",
to_numpy_array(
self.validation_dataloader[model_key].sampler.weights
),
)

# Learning rate
self.warmup_steps = training_params.get("warmup_steps", 0)
self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
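For orientation, the calls added above pass the training sampler's per-system weights (a torch tensor) to the summary as the `prob` column. A minimal, self-contained sketch of that conversion step, using illustrative weights and assuming `to_numpy_array` amounts to a detach-to-CPU NumPy conversion for a plain tensor:

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Illustrative per-system weights; in training.py they come from the
# weighted sampler built over the training systems.
sampler = WeightedRandomSampler(weights=[0.7, 0.3], num_samples=10, replacement=True)

# The summary only formats plain floats, so the tensor is converted to NumPy
# first (assumed to mirror what to_numpy_array does for a CPU tensor).
prob = sampler.weights.detach().cpu().numpy()
print(prob)  # e.g. [0.7 0.3], rendered in the "prob" column of the summary
```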
22 changes: 22 additions & 0 deletions deepmd/pt/utils/dataloader.py
@@ -39,6 +39,7 @@
DataRequirementItem,
)
from deepmd.utils.data_system import (
print_summary,
prob_sys_size_ext,
process_sys_probs,
)
@@ -91,6 +92,7 @@ def construct_dataset(system):
self.total_batch = 0

self.dataloaders = []
self.batch_sizes = []
for system in self.systems:
if dist.is_initialized():
system_sampler = DistributedSampler(system)
@@ -110,6 +112,7 @@ def construct_dataset(system):
self.batch_size += 1
else:
self.batch_size = batch_size
self.batch_sizes.append(self.batch_size)
system_dataloader = DataLoader(
dataset=system,
batch_size=self.batch_size,
@@ -155,6 +158,25 @@ def add_data_requirement(self, data_requirement: List[DataRequirementItem]):
for system in self.systems:
system.add_data_requirement(data_requirement)

def print_summary(
self,
name: str,
prob: List[float],
):
print_summary(
name,
len(self.systems),
[ss.system for ss in self.systems],
[ss._natoms for ss in self.systems],
self.batch_sizes,
[
ss._data_system.get_sys_numb_batch(self.batch_sizes[ii])
for ii, ss in enumerate(self.systems)
],
prob,
[ss._data_system.pbc for ss in self.systems],
)


_sentinel = object()
QUEUESIZE = 32
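The loader-level `print_summary` above only aggregates state each wrapped dataset already holds; in particular, the per-system batch count comes from the underlying `DeepmdData` object. A small sketch of that lookup, with a hypothetical system path (the constructor mirrors the one used in dataset.py below):

```python
from deepmd.utils.data import DeepmdData

# Hypothetical deepmd-format system directory, for illustration only.
data = DeepmdData(sys_path="data/water/sys0")

# How many batches the system yields at a given batch size; this is the
# value that feeds the "n_bch" column of the printed summary.
n_batch = data.get_sys_numb_batch(batch_size=1)
print(n_batch)
```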
1 change: 1 addition & 0 deletions deepmd/pt/utils/dataset.py
@@ -29,6 +29,7 @@ def __init__(
- batch_size: Max frame count in a batch.
- type_map: Atom types.
"""
self.system = system
self._type_map = type_map
self._data_system = DeepmdData(
sys_path=system, shuffle_test=shuffle, type_map=self._type_map
112 changes: 78 additions & 34 deletions deepmd/utils/data_system.py
@@ -556,40 +556,16 @@ def get_batch_size(self) -> int:
"""Get the batch size."""
return self.batch_size

def _format_name_length(self, name, width):
if len(name) <= width:
return "{: >{}}".format(name, width)
else:
name = name[-(width - 3) :]
name = "-- " + name
return name

def print_summary(self, name):
# width 65
sys_width = 42
log.info(
f"---Summary of DataSystem: {name:13s}-----------------------------------------------"
)
log.info("found %d system(s):" % self.nsystems)
log.info(
("%s " % self._format_name_length("system", sys_width))
+ ("%6s %6s %6s %9s %3s" % ("natoms", "bch_sz", "n_bch", "prob", "pbc"))
)
for ii in range(self.nsystems):
log.info(
"%s %6d %6d %6d %9.3e %3s"
% (
self._format_name_length(self.system_dirs[ii], sys_width),
self.natoms[ii],
# TODO batch size * nbatches = number of structures
self.batch_size[ii],
self.nbatches[ii],
self.sys_probs[ii],
"T" if self.data_systems[ii].pbc else "F",
)
)
log.info(
"--------------------------------------------------------------------------------------"
def print_summary(self, name: str):
print_summary(
name,
self.nsystems,
self.system_dirs,
self.natoms,
self.batch_size,
self.nbatches,
self.sys_probs,
[ii.pbc for ii in self.data_systems],
)

def _make_auto_bs(self, rule):
@@ -625,6 +601,74 @@ def _check_type_map_consistency(self, type_map_list):
return ret


def _format_name_length(name, width):
if len(name) <= width:
return "{: >{}}".format(name, width)
else:
name = name[-(width - 3) :]
name = "-- " + name
return name


def print_summary(
name: str,
nsystems: int,
system_dirs: List[str],
natoms: List[int],
batch_size: List[int],
nbatches: List[int],
sys_probs: List[float],
pbc: List[bool],
):
"""Print summary of systems.
Parameters
----------
name : str
The name of the system
nsystems : int
The number of systems
system_dirs : list of str
The directories of the systems
natoms : list of int
The number of atoms
batch_size : list of int
The batch size
nbatches : list of int
The number of batches
sys_probs : list of float
The probabilities
pbc : list of bool
The periodic boundary conditions
"""
# width 65
sys_width = 42
log.info(
f"---Summary of DataSystem: {name:13s}-----------------------------------------------"
)
log.info("found %d system(s):" % nsystems)
log.info(
("%s " % _format_name_length("system", sys_width))
+ ("%6s %6s %6s %9s %3s" % ("natoms", "bch_sz", "n_bch", "prob", "pbc"))
)
for ii in range(nsystems):
log.info(
"%s %6d %6d %6d %9.3e %3s"
% (
_format_name_length(system_dirs[ii], sys_width),
natoms[ii],
# TODO batch size * nbatches = number of structures
batch_size[ii],
nbatches[ii],
sys_probs[ii],
"T" if pbc[ii] else "F",
)
)
log.info(
"--------------------------------------------------------------------------------------"
)


def process_sys_probs(sys_probs, nbatch):
sys_probs = np.array(sys_probs)
type_filter = sys_probs >= 0
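Since `print_summary` is now a module-level helper, the class method above and the new PyTorch loader's `print_summary` both delegate to the same formatting code. A minimal sketch of calling it directly, with entirely illustrative values (the helper only formats and logs them):

```python
import logging

from deepmd.utils.data_system import print_summary

logging.basicConfig(level=logging.INFO)

# All values below are made up for illustration; print_summary only logs a table.
print_summary(
    name="training",
    nsystems=2,
    system_dirs=["data/water/sys0", "data/water/sys1"],
    natoms=[192, 96],
    batch_size=[1, 2],
    nbatches=[400, 200],
    sys_probs=[0.667, 0.333],
    pbc=[True, False],
)
```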
