Skip to content

Commit

Permalink
Try to fix training Loss inconsistent after resume from old checkpoint (
Browse files Browse the repository at this point in the history
huggingface#25872)

* fix loss inconsistent after resume  huggingface#25340

* fix typo

* clean code

* reformatted code

* adjust code according to comments

* adjust check_dataloader_randomsampler location

* return sampler only

* handle sampler is None

* Update src/transformers/trainer_pt_utils.py

thanks @amyeroberts

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
dumpmemory and amyeroberts authored Sep 7, 2023
1 parent c5e66a4 commit fb7d246
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_less_than_1_11
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
Expand All @@ -85,6 +85,7 @@
distributed_broadcast_scalars,
distributed_concat,
find_batch_size,
get_dataloader_sampler,
get_model_param_count,
get_module_class_from_name,
get_parameter_names,
Expand Down Expand Up @@ -219,6 +220,7 @@
if TYPE_CHECKING:
import optuna


logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -1808,8 +1810,17 @@ def _inner_training_loop(
# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
if not args.ignore_data_skip:
for epoch in range(epochs_trained):
for _ in train_dataloader:
break
sampler = get_dataloader_sampler(train_dataloader)
is_random_sampler = isinstance(sampler, RandomSampler)
if is_torch_less_than_1_11 or not is_random_sampler:
# We just need to begin an iteration to create the randomization of the sampler.
for _ in train_dataloader:
break
else:
# Otherwise we need to call the whooooole sampler cause there is some random operation added
# AT THE VERY END!
sampler = sampler if sampler is not None else []
_ = list(sampler)

total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@
logger = logging.get_logger(__name__)


def get_dataloader_sampler(dataloader):
if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
return get_dataloader_sampler(dataloader.batch_sampler)
elif hasattr(dataloader, "sampler"):
return dataloader.sampler


def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
if isinstance(tensor_or_array, torch.Tensor):
if hasattr(torch, "atleast_1d"):
Expand Down

0 comments on commit fb7d246

Please sign in to comment.