
refactor: simplify dataset construction #4437

Merged
14 commits merged on Dec 9, 2024
57 changes: 16 additions & 41 deletions deepmd/pt/train/training.py
@@ -1053,47 +1053,22 @@ def save_model(self, save_path, lr=0.0, step=0) -> None:
checkpoint_files[0].unlink()

def get_data(self, is_train=True, task_key="Default"):
if not self.multi_task:
if is_train:
try:
batch_data = next(iter(self.training_data))
except StopIteration:
# Refresh the status of the dataloader to start from a new epoch
with torch.device("cpu"):
self.training_data = BufferedIterator(
iter(self.training_dataloader)
)
batch_data = next(iter(self.training_data))
else:
if self.validation_data is None:
return {}, {}, {}
try:
batch_data = next(iter(self.validation_data))
except StopIteration:
self.validation_data = BufferedIterator(
iter(self.validation_dataloader)
)
batch_data = next(iter(self.validation_data))
else:
if is_train:
try:
batch_data = next(iter(self.training_data[task_key]))
except StopIteration:
# Refresh the status of the dataloader to start from a new epoch
self.training_data[task_key] = BufferedIterator(
iter(self.training_dataloader[task_key])
)
batch_data = next(iter(self.training_data[task_key]))
else:
if self.validation_data[task_key] is None:
return {}, {}, {}
try:
batch_data = next(iter(self.validation_data[task_key]))
except StopIteration:
self.validation_data[task_key] = BufferedIterator(
iter(self.validation_dataloader[task_key])
)
batch_data = next(iter(self.validation_data[task_key]))
data, dataloader = (
(self.training_data, self.training_dataloader)
if is_train
else (self.validation_data, self.validation_dataloader)
)
if data is None and not is_train:
return {}, {}, {}
if self.multi_task:
data = data[task_key]
dataloader = dataloader[task_key]
try:
batch_data = next(iter(data))
except StopIteration:
# Refresh the status of the dataloader to start from a new epoch
data = BufferedIterator(iter(dataloader))
batch_data = next(iter(data))

for key in batch_data.keys():
if key == "sid" or key == "fid" or key == "box" or "find_" in key:
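The rewritten `get_data` collapses the four near-duplicate branches into one path: pick the `(data, dataloader)` pair for the current mode, index both by `task_key` in multi-task mode, and rebuild the `BufferedIterator` once an epoch is exhausted. Because `BufferedIterator.__iter__` returns `self`, `next(iter(data))` simply advances the existing buffer. A minimal, self-contained sketch of the refresh-on-`StopIteration` idea, with hypothetical stand-ins (`make_loader`, `fetch_batch`) rather than DeePMD-kit functions:

```python
# Minimal sketch of the refresh-on-exhaustion pattern; `make_loader` and
# `fetch_batch` are illustrative stand-ins, not DeePMD-kit APIs.

def make_loader():
    # Three batches stand in for one epoch of a real DataLoader.
    return [{"batch": 0}, {"batch": 1}, {"batch": 2}]


class LoaderState:
    def __init__(self) -> None:
        self.loader = make_loader()
        self.iterator = iter(self.loader)


def fetch_batch(state: LoaderState):
    try:
        return next(state.iterator)
    except StopIteration:
        # Epoch exhausted: restart from the underlying loader and keep going.
        state.iterator = iter(state.loader)
        return next(state.iterator)


state = LoaderState()
for step in range(7):  # runs past the 3-batch "epoch" without raising
    print(step, fetch_batch(state))
```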
110 changes: 49 additions & 61 deletions deepmd/pt/utils/dataloader.py
@@ -1,11 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import queue
import time
from multiprocessing.dummy import (
from functools import (
partial,
)
from multiprocessing import (
Pool,
)
from queue import (
Queue,
)
from threading import (
Thread,
)
@@ -52,6 +57,13 @@ def setup_seed(seed) -> None:
torch.backends.cudnn.deterministic = True


def construct_dataset(system, type_map):
return DeepmdDataSetForLoader(
system=system,
type_map=type_map,
)


class DpLoaderSet(Dataset):
"""A dataset for storing DataLoaders to multiple Systems.

@@ -87,11 +99,7 @@ def __init__(
if len(systems) >= 100:
log.info(f"Constructing DataLoaders from {len(systems)} systems")

def construct_dataset(system):
return DeepmdDataSetForLoader(
system=system,
type_map=type_map,
)
construct_dataset_systems = partial(construct_dataset, type_map=type_map)

with Pool(
os.cpu_count()
@@ -101,7 +109,7 @@ def construct_dataset(system):
else 1
)
) as pool:
self.systems = pool.map(construct_dataset, systems)
self.systems = pool.map(construct_dataset_systems, systems)

self.sampler_list: list[DistributedSampler] = []
self.index = []
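Hoisting `construct_dataset` to module level and binding `type_map` with `functools.partial` is what makes the switch from `multiprocessing.dummy` (a thread pool) to a real process `Pool` possible: the worker callable is shipped to child processes by pickling, and a nested function cannot be pickled. A small sketch of the same pattern under that assumption; `build_item` and its arguments are hypothetical names, not DeePMD-kit ones:

```python
import os
from functools import partial
from multiprocessing import Pool


# Module-level function: picklable, so Pool workers can receive it.
def build_item(name: str, tag: str) -> str:
    return f"{tag}:{name}"


if __name__ == "__main__":
    names = ["sys0", "sys1", "sys2"]
    build_with_tag = partial(build_item, tag="H2O")  # bind the shared argument once
    with Pool(min(os.cpu_count() or 1, len(names))) as pool:
        items = pool.map(build_with_tag, names)
    print(items)  # ['H2O:sys0', 'H2O:sys1', 'H2O:sys2']
```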
@@ -185,85 +193,65 @@ def print_summary(
name: str,
prob: list[float],
) -> None:
print_summary(
name,
len(self.systems),
[ss.system for ss in self.systems],
[ss._natoms for ss in self.systems],
self.batch_sizes,
[
ss._data_system.get_sys_numb_batch(self.batch_sizes[ii])
for ii, ss in enumerate(self.systems)
],
prob,
[ss._data_system.pbc for ss in self.systems],
)


_sentinel = object()
QUEUESIZE = 32
rank = dist.get_rank() if dist.is_initialized() else 0
if rank == 0:
print_summary(
name,
len(self.systems),
[ss.system for ss in self.systems],
[ss._natoms for ss in self.systems],
self.batch_sizes,
[
ss._data_system.get_sys_numb_batch(self.batch_sizes[ii])
for ii, ss in enumerate(self.systems)
],
prob,
[ss._data_system.pbc for ss in self.systems],
)


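Printing the dataset summary only on rank 0 avoids one copy of the table per process when training under `torch.distributed`; the `dist.is_initialized()` check keeps single-process runs working, since `dist.get_rank()` is only valid after the process group has been initialized. A minimal sketch of the same guard:

```python
import torch.distributed as dist


def log_once(message: str) -> None:
    # Falls back to rank 0 in single-process runs, so the message is
    # still printed exactly once.
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        print(message)


log_once("data summary goes here")
```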
class BackgroundConsumer(Thread):
def __init__(self, queue, source, max_len) -> None:
Thread.__init__(self)
def __init__(self, queue, source) -> None:
super().__init__()
self.daemon = True
self._queue = queue
self._source = source # Main DL iterator
self._max_len = max_len #

def run(self) -> None:
for item in self._source:
self._queue.put(item) # Blocking if the queue is full

# Signal the consumer we are done.
self._queue.put(_sentinel)
# Signal the consumer we are done; this should not happen for DataLoader
self._queue.put(StopIteration())


QUEUESIZE = 32


class BufferedIterator:
def __init__(self, iterable) -> None:
self._queue = queue.Queue(QUEUESIZE)
self._queue = Queue(QUEUESIZE)
self._iterable = iterable
self._consumer = None

self.start_time = time.time()
self.warning_time = None
self.total = len(iterable)

def _create_consumer(self) -> None:
self._consumer = BackgroundConsumer(self._queue, self._iterable, self.total)
self._consumer.daemon = True
self._consumer = BackgroundConsumer(self._queue, self._iterable)
self._consumer.start()
self.len = len(iterable)

def __iter__(self):
return self

def __len__(self) -> int:
return self.total
return self.len

def __next__(self):
# Create consumer if not created yet
if self._consumer is None:
self._create_consumer()
# Notify the user if there is a data loading bottleneck
if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)):
if time.time() - self.start_time > 5 * 60:
if (
self.warning_time is None
or time.time() - self.warning_time > 15 * 60
):
log.warning(
"Data loading buffer is empty or nearly empty. This may "
"indicate a data loading bottleneck, and increasing the "
"number of workers (--num-workers) may help."
)
self.warning_time = time.time()

# Get next example
start_wait = time.time()
item = self._queue.get()
wait_time = time.time() - start_wait
if (
wait_time > 1.0
): # Even for Multi-Task training, each step usually takes < 1s
log.warning(f"Data loading is slow, waited {wait_time:.2f} seconds.")
if isinstance(item, Exception):
raise item
if item is _sentinel:
raise StopIteration
return item


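The reworked `BackgroundConsumer`/`BufferedIterator` pair keeps the producer-consumer design: a daemon thread drains the source iterator into a bounded `Queue`, and exhaustion is now signalled by enqueueing a `StopIteration()` instance that `__next__` re-raises, replacing the old module-level `_sentinel` object. The periodic queue-size warning is likewise replaced by a per-fetch check that logs whenever a single `get()` takes longer than a second. A stripped-down sketch of the pattern, standard library only:

```python
from queue import Queue
from threading import Thread


class TinyBufferedIterator:
    """Buffers items from `iterable` in a background thread."""

    def __init__(self, iterable, maxsize: int = 4) -> None:
        self._queue: Queue = Queue(maxsize)
        thread = Thread(target=self._fill, args=(iterable,), daemon=True)
        thread.start()

    def _fill(self, iterable) -> None:
        for item in iterable:
            self._queue.put(item)  # blocks while the buffer is full
        self._queue.put(StopIteration())  # sentinel: the source is exhausted

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get()
        if isinstance(item, Exception):  # covers the StopIteration sentinel
            raise item
        return item


print(list(TinyBufferedIterator(range(5))))  # [0, 1, 2, 3, 4]
```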
3 changes: 3 additions & 0 deletions deepmd/utils/path.py
@@ -44,6 +44,9 @@ def __new__(cls, path: str, mode: str = "r"):
raise FileNotFoundError(f"{path} not found")
return super().__new__(cls)

def __getnewargs__(self):
return (self.path, self.mode)
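`__getnewargs__` tells pickle which arguments to pass to `__new__` when an instance is reconstructed; without it, unpickling a class whose `__new__` requires a `path` argument fails. Presumably this is what lets the dataset objects built inside the multiprocessing `Pool` above, which hold these path objects, travel back to the parent process. A minimal sketch of the mechanism using a hypothetical class, not the DeePMD-kit one:

```python
import pickle


class TaggedPath:
    def __new__(cls, path: str, mode: str = "r"):
        obj = super().__new__(cls)
        obj.path = path
        obj.mode = mode
        return obj

    def __getnewargs__(self):
        # Arguments handed back to __new__ when the pickle is loaded.
        return (self.path, self.mode)


p = pickle.loads(pickle.dumps(TaggedPath("some/system", "r")))
print(p.path, p.mode)  # some/system r
```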

@abstractmethod
def load_numpy(self) -> np.ndarray:
"""Load NumPy array.
Expand Down