Skip to content

Commit

Permalink
Start with eval (#86)
Browse files Browse the repository at this point in the history
* Add start with eval option

* Ping for training

* Drop p3.6 from CI

* Turn off telemetry on CI

* Bump up to v0.0.18

* Fix mixedprecision
  • Loading branch information
erogol authored Dec 13, 2022
1 parent 0c08730 commit d3e5988
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,16 @@ def _model_train_step(
return model.module.train_step(*input_args)
return model.train_step(*input_args)

def _get_autocast_args(self, mixed_precision: bool):
device = "cpu"
dtype = None
if self.use_cuda:
device = "cuda"
dtype = torch.float16 if mixed_precision else torch.float32
elif mixed_precision:
dtype = torch.bfloat16
return device, dtype

def _optimize(
self,
batch: Dict,
Expand Down Expand Up @@ -1007,7 +1017,10 @@ def _optimize(
step_start_time = time.time()

# forward pass and loss computation
with torch.cuda.amp.autocast(enabled=config.mixed_precision):
device, dtype = self._get_autocast_args(config.mixed_precision)
with torch.autocast(
device_type=device, dtype=dtype, enabled=config.mixed_precision
):
if optimizer_idx is not None:
outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx)
else:
Expand Down

0 comments on commit d3e5988

Please sign in to comment.