Skip to content

Commit

Permalink
perf: optimize training loop (deepmodeling#4426)
Browse files Browse the repository at this point in the history
Improvements to the training process:

*
[`deepmd/pt/train/training.py`](diffhunk://#diff-a90c90dc0e6a17fbe2e930f91182805b83260484c9dc1cfac3331378ffa34935R659):
Added a check to skip setting the model to training mode if it already
is. The profiling result shows it takes some time to recursively set it
to all models.

*
[`deepmd/pt/train/training.py`](diffhunk://#diff-a90c90dc0e6a17fbe2e930f91182805b83260484c9dc1cfac3331378ffa34935L686-L690):
Modified the gradient clipping function to include the
`error_if_nonfinite` parameter, and removed the manual check for
non-finite gradients and the associated exception raising.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
	- Improved training loop with enhanced error handling and control flow.
	- Updated gradient clipping logic for better error detection.
	- Refined logging functionality for training and validation results.

- **Bug Fixes**
	- Prevented redundant training calls by adding conditional checks.

- **Documentation**
- Clarified method logic in the `Trainer` class without changing method
signatures.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 037cf3f)
  • Loading branch information
caic99 authored and njzjz committed Dec 22, 2024
1 parent db3d48e commit 69a1628
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,6 @@ def step(_step_id, task_key="Default") -> None:
# PyTorch Profiler
if self.enable_profiler or self.profiling:
prof.step()
self.wrapper.train()
if isinstance(self.lr_exp, dict):
_lr = self.lr_exp[task_key]
else:
Expand All @@ -682,12 +681,11 @@ def step(_step_id, task_key="Default") -> None:
)
loss.backward()
if self.gradient_max_norm > 0.0:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.wrapper.parameters(), self.gradient_max_norm
torch.nn.utils.clip_grad_norm_(
self.wrapper.parameters(),
self.gradient_max_norm,
error_if_nonfinite=True,
)
if not torch.isfinite(grad_norm).all():
# check local gradnorm single GPU case, trigger NanDetector
raise FloatingPointError("gradients are Nan/Inf")
with torch.device("cpu"):
self.optimizer.step()
self.scheduler.step()
Expand Down Expand Up @@ -766,7 +764,7 @@ def fake_model():
if self.display_in_training and (
display_step_id % self.disp_freq == 0 or display_step_id == 1
):
self.wrapper.eval()
self.wrapper.eval() # Will set to train mode before fininshing validation

def log_loss_train(_loss, _more_loss, _task_key="Default"):
results = {}
Expand Down Expand Up @@ -872,6 +870,7 @@ def log_loss_valid(_task_key="Default"):
learning_rate=None,
)
)
self.wrapper.train()

current_time = time.time()
train_time = current_time - self.t0
Expand Down Expand Up @@ -927,6 +926,7 @@ def log_loss_valid(_task_key="Default"):
f"{task_key}/{item}", more_loss[item], display_step_id
)

self.wrapper.train()
self.t0 = time.time()
self.total_train_time = 0.0
for step_id in range(self.num_steps):
Expand Down

0 comments on commit 69a1628

Please sign in to comment.