Skip to content

Commit

Permalink
Update training.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 29, 2024
1 parent 35598d2 commit c0a0cfc
Showing 1 changed file with 30 additions and 46 deletions.
76 changes: 30 additions & 46 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,61 +128,45 @@ def get_opt_param(params):
return opt_type, opt_param

def get_data_loader(_training_data, _validation_data, _training_params):
if "auto_prob" in _training_params["training_data"]:
train_sampler = get_weighted_sampler(
_training_data, _training_params["training_data"]["auto_prob"]
)
elif "sys_probs" in _training_params["training_data"]:
train_sampler = get_weighted_sampler(
_training_data,
_training_params["training_data"]["sys_probs"],
sys_prob=True,
)
else:
train_sampler = get_weighted_sampler(_training_data, "prob_sys_size")

if train_sampler is None:
log.warning(
"Sampler not specified!"
) # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
training_dataloader = DataLoader(
_training_data,
sampler=train_sampler,
batch_size=None,
num_workers=NUM_WORKERS, # setting to 0 diverges the behavior of its iterator; should be >=1
drop_last=False,
pin_memory=True,
)
with torch.device("cpu"):
training_data_buffered = BufferedIterator(iter(training_dataloader))
if _validation_data is not None:
if "auto_prob" in _training_params["validation_data"]:
valid_sampler = get_weighted_sampler(
_validation_data,
_training_params["validation_data"]["auto_prob"],
def get_dataloader_and_buffer(_data, _params):
if "auto_prob" in _training_params["training_data"]:
_sampler = get_weighted_sampler(
_data, _params["training_data"]["auto_prob"]
)
elif "sys_probs" in _training_params["validation_data"]:
valid_sampler = get_weighted_sampler(
_validation_data,
_training_params["validation_data"]["sys_probs"],
elif "sys_probs" in _training_params["training_data"]:
_sampler = get_weighted_sampler(
_data,
_params["training_data"]["sys_probs"],
sys_prob=True,
)
else:
valid_sampler = get_weighted_sampler(
_validation_data, "prob_sys_size"
)
validation_dataloader = DataLoader(
_validation_data,
sampler=valid_sampler,
_sampler = get_weighted_sampler(_data, "prob_sys_size")

if _sampler is None:
log.warning(
"Sampler not specified!"
) # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
_dataloader = DataLoader(
_data,
sampler=_sampler,
batch_size=None,
num_workers=min(NUM_WORKERS, 1),
num_workers=NUM_WORKERS, # setting to 0 diverges the behavior of its iterator; should be >=1
drop_last=False,
pin_memory=True,
)
with torch.device("cpu"):
validation_data_buffered = BufferedIterator(
iter(validation_dataloader)
)
_data_buffered = BufferedIterator(iter(_dataloader))
return _dataloader, _data_buffered

training_dataloader, training_data_buffered = get_dataloader_and_buffer(
_training_data, _training_params
)

if _validation_data is not None:
(
validation_dataloader,
validation_data_buffered,
) = get_dataloader_and_buffer(_validation_data, _training_params)
if _training_params.get("validation_data", None) is not None:
valid_numb_batch = _training_params["validation_data"].get(
"numb_btch", 1
Expand Down

0 comments on commit c0a0cfc

Please sign in to comment.