From 5c89af182eb8c43de7c2e1f8a19e0423ca0046b9 Mon Sep 17 00:00:00 2001
From: markus583
Date: Wed, 6 Mar 2024 16:03:25 +0000
Subject: [PATCH] optionally, skip eval loss

---
 wtpsplit/train/adaptertrainer.py          | 329 ++++++++++++-----------
 wtpsplit/train/train_adapter_parallel.py  |   2 +
 2 files changed, 169 insertions(+), 162 deletions(-)

diff --git a/wtpsplit/train/adaptertrainer.py b/wtpsplit/train/adaptertrainer.py
index 962cbc52..5303ee48 100644
--- a/wtpsplit/train/adaptertrainer.py
+++ b/wtpsplit/train/adaptertrainer.py
@@ -118,6 +118,7 @@ def __init__(
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
         preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
         logging_prefix="",
+        skip_eval_loss: bool = False,
     ):
         super().__init__(
             model,
@@ -134,6 +135,7 @@ def __init__(
         )
 
         self.logging_prefix = logging_prefix
+        self.skip_eval_loss = skip_eval_loss
 
         if adapter_names is not None:
             self.model.backbone.set_active_adapters(adapter_names)
@@ -803,182 +805,185 @@ def evaluation_loop(
         Works both with or without labels.
         """
         args = self.args
+
+        if not self.skip_eval_loss:
+            prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+            # if eval is called w/o train init deepspeed here
+            if args.deepspeed and not self.deepspeed:
+                # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
+                # from the checkpoint eventually
+                deepspeed_engine, _, _ = deepspeed_init(
+                    self, num_training_steps=0, resume_from_checkpoint=None, inference=True
+                )
+                self.model = deepspeed_engine.module
+                self.model_wrapped = deepspeed_engine
+                self.deepspeed = deepspeed_engine
 
-        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+            model = self._wrap_model(self.model, training=False, dataloader=dataloader)
 
-        # if eval is called w/o train init deepspeed here
-        if args.deepspeed and not self.deepspeed:
-            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
-            # from the checkpoint eventually
-            deepspeed_engine, _, _ = deepspeed_init(
-                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
-            )
-            self.model = deepspeed_engine.module
-            self.model_wrapped = deepspeed_engine
-            self.deepspeed = deepspeed_engine
+            # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+            # while ``train`` is running, cast it to the right dtype first and then put on device
+            if not self.is_in_train:
+                if args.fp16_full_eval:
+                    model = model.to(dtype=torch.float16, device=args.device)
+                elif args.bf16_full_eval:
+                    model = model.to(dtype=torch.bfloat16, device=args.device)
 
-        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+            batch_size = self.args.eval_batch_size
 
-        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
-        # while ``train`` is running, cast it to the right dtype first and then put on device
-        if not self.is_in_train:
-            if args.fp16_full_eval:
-                model = model.to(dtype=torch.float16, device=args.device)
-            elif args.bf16_full_eval:
-                model = model.to(dtype=torch.bfloat16, device=args.device)
+            logger.info(f"***** Running {description} *****")
+            if has_length(dataloader):
+                logger.warning(f"  Num examples = {self.num_examples(dataloader)}")
+            else:
+                logger.info("  Num examples: Unknown")
+            logger.info(f"  Batch size = {batch_size}")
 
-        batch_size = self.args.eval_batch_size
+            model.eval()
 
-        logger.info(f"***** Running {description} *****")
-        if has_length(dataloader):
-            logger.warning(f"  Num examples = {self.num_examples(dataloader)}")
-        else:
-            logger.info("  Num examples: Unknown")
-        logger.info(f"  Batch size = {batch_size}")
-
-        model.eval()
-
-        self.callback_handler.eval_dataloader = dataloader
-        # Do this before wrapping.
-        eval_dataset = getattr(dataloader, "dataset", None)
-
-        # MODIFIED: not necessary.
-        # if is_torch_tpu_available():
-        #     dataloader = pl.MpDeviceLoader(dataloader, args.device)  # .per_device_loader(args.device)
-
-        if args.past_index >= 0:
-            self._past = None
-
-        # Initialize containers
-        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
-        losses_host = None
-        preds_host = None
-        labels_host = None
-        inputs_host = None
-
-        # losses/preds/labels on CPU (final containers)
-        all_losses = None
-        all_preds = None
-        all_labels = None
-        all_inputs = None
-        # Will be useful when we have an iterable dataset so don't know its length.
-
-        observed_num_examples = 0
-        # Main evaluation loop
-        for step, inputs in enumerate(dataloader):
-            # Update the observed num examples
-            observed_batch_size = find_batch_size(inputs)
-            if observed_batch_size is not None:
-                observed_num_examples += observed_batch_size
-                # For batch samplers, batch_size is not known by the dataloader in advance.
-                if batch_size is None:
-                    batch_size = observed_batch_size
-
-            # Prediction step
-            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
-            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
+            self.callback_handler.eval_dataloader = dataloader
+            # Do this before wrapping.
+            eval_dataset = getattr(dataloader, "dataset", None)
 
+            # MODIFIED: not necessary.
+            # if is_torch_tpu_available():
-        #     xm.mark_step()
+            #     dataloader = pl.MpDeviceLoader(dataloader, args.device)  # .per_device_loader(args.device)
 
-        # Update containers on host
-        if loss is not None:
-            # MODIFIED: do not gather across devices. (different loss on each device)
-            losses = loss.repeat(batch_size)
-            losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
-        if labels is not None:
-            labels = self._pad_across_processes(labels)
-            # MODIFIED: do not gather across devices.
-            # labels = self._nested_gather(labels)
-            labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
-        if inputs_decode is not None:
-            inputs_decode = self._pad_across_processes(inputs_decode)
-            # MODIFIED: do not gather across devices.
-            # inputs_decode = self._nested_gather(inputs_decode)
-            inputs_host = (
-                inputs_decode
-                if inputs_host is None
-                else nested_concat(inputs_host, inputs_decode, padding_index=-100)
-            )
-        if logits is not None:
-            logits = self._pad_across_processes(logits)
-            # logits = self._nested_gather(logits)
-            if self.preprocess_logits_for_metrics is not None:
-                logits = self.preprocess_logits_for_metrics(logits, labels)
-            preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
-        self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
-
-        # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
-        if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
-            if losses_host is not None:
-                losses = nested_numpify(losses_host)
-                all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
-            if preds_host is not None:
-                logits = nested_numpify(preds_host)
-                all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
-            if inputs_host is not None:
-                inputs_decode = nested_numpify(inputs_host)
-                all_inputs = (
+            if args.past_index >= 0:
+                self._past = None
+
+            # Initialize containers
+            # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
+            losses_host = None
+            preds_host = None
+            labels_host = None
+            inputs_host = None
+
+            # losses/preds/labels on CPU (final containers)
+            all_losses = None
+            all_preds = None
+            all_labels = None
+            all_inputs = None
+            # Will be useful when we have an iterable dataset so don't know its length.
+
+            observed_num_examples = 0
+            # Main evaluation loop
+            for step, inputs in enumerate(dataloader):
+                # Update the observed num examples
+                observed_batch_size = find_batch_size(inputs)
+                if observed_batch_size is not None:
+                    observed_num_examples += observed_batch_size
+                    # For batch samplers, batch_size is not known by the dataloader in advance.
+                    if batch_size is None:
+                        batch_size = observed_batch_size
+
+                # Prediction step
+                loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+                inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
+
+                # MODIFIED: not necessary.
+                # if is_torch_tpu_available():
+                #     xm.mark_step()
+
+                # Update containers on host
+                if loss is not None:
+                    # MODIFIED: do not gather across devices. (different loss on each device)
+                    losses = loss.repeat(batch_size)
+                    losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+                if labels is not None:
+                    labels = self._pad_across_processes(labels)
+                    # MODIFIED: do not gather across devices.
+                    # labels = self._nested_gather(labels)
+                    labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+                if inputs_decode is not None:
+                    inputs_decode = self._pad_across_processes(inputs_decode)
+                    # MODIFIED: do not gather across devices.
+                    # inputs_decode = self._nested_gather(inputs_decode)
+                    inputs_host = (
                         inputs_decode
-                    if all_inputs is None
-                    else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+                        if inputs_host is None
+                        else nested_concat(inputs_host, inputs_decode, padding_index=-100)
+                    )
+                if logits is not None:
+                    logits = self._pad_across_processes(logits)
+                    # logits = self._nested_gather(logits)
+                    if self.preprocess_logits_for_metrics is not None:
+                        logits = self.preprocess_logits_for_metrics(logits, labels)
+                    preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+                self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+                # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+                if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+                    if losses_host is not None:
+                        losses = nested_numpify(losses_host)
+                        all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+                    if preds_host is not None:
+                        logits = nested_numpify(preds_host)
+                        all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+                    if inputs_host is not None:
+                        inputs_decode = nested_numpify(inputs_host)
+                        all_inputs = (
+                            inputs_decode
+                            if all_inputs is None
+                            else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+                        )
+                    if labels_host is not None:
+                        labels = nested_numpify(labels_host)
+                        all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+
+                    # Set back to None to begin a new accumulation
+                    losses_host, preds_host, inputs_host, labels_host = (
+                        None,
+                        None,
+                        None,
+                        None,
                     )
-            if labels_host is not None:
-                labels = nested_numpify(labels_host)
-                all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
-
-            # Set back to None to begin a new accumulation
-            losses_host, preds_host, inputs_host, labels_host = (
-                None,
-                None,
-                None,
-                None,
-            )
-
-        if args.past_index and hasattr(self, "_past"):
-            # Clean the state at the end of the evaluation loop
-            delattr(self, "_past")
+            if args.past_index and hasattr(self, "_past"):
+                # Clean the state at the end of the evaluation loop
+                delattr(self, "_past")
+
+            # Gather all remaining tensors and put them back on the CPU
+            if losses_host is not None:
+                losses = nested_numpify(losses_host)
+                all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+            if preds_host is not None:
+                logits = nested_numpify(preds_host)
+                all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+            if inputs_host is not None:
+                inputs_decode = nested_numpify(inputs_host)
+                all_inputs = (
+                    inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+                )
+            if labels_host is not None:
+                labels = nested_numpify(labels_host)
+                all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+
+            # Number of samples
+            if has_length(eval_dataset):
+                num_samples = len(eval_dataset)
+            # The instance check is weird and does not actually check for the type, but whether the dataset has the right
+            # methods. Therefore we need to make sure it also has the attribute.
-        elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
-            num_samples = eval_dataset.num_examples
+            elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
+                num_samples = eval_dataset.num_examples
+            else:
+                if has_length(dataloader):
+                    num_samples = self.num_examples(dataloader)
+                else:  # both len(dataloader.dataset) and len(dataloader) fail
+                    num_samples = observed_num_examples
+
+            # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
+            # samplers has been rounded to a multiple of batch_size, so we truncate.
+            if all_losses is not None:
+                all_losses = all_losses[:num_samples]
+            if all_preds is not None:
+                all_preds = nested_truncate(all_preds, num_samples)
+            if all_labels is not None:
+                all_labels = nested_truncate(all_labels, num_samples)
+            if all_inputs is not None:
+                all_inputs = nested_truncate(all_inputs, num_samples)
-        else:
-            if has_length(dataloader):
-                num_samples = self.num_examples(dataloader)
-            else:  # both len(dataloader.dataset) and len(dataloader) fail
-                num_samples = observed_num_examples
-
-        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
-        # samplers has been rounded to a multiple of batch_size, so we truncate.
-        if all_losses is not None:
-            all_losses = all_losses[:num_samples]
-        if all_preds is not None:
-            all_preds = nested_truncate(all_preds, num_samples)
-        if all_labels is not None:
-            all_labels = nested_truncate(all_labels, num_samples)
-        if all_inputs is not None:
-            all_inputs = nested_truncate(all_inputs, num_samples)
+        else:
+            all_losses, all_preds, all_labels, all_inputs, num_samples = None, None, None, None, 0
 
         # Metrics!
         # MODIFIED: removed since done in compute_metrics
diff --git a/wtpsplit/train/train_adapter_parallel.py b/wtpsplit/train/train_adapter_parallel.py
index 02146698..ed715fd2 100644
--- a/wtpsplit/train/train_adapter_parallel.py
+++ b/wtpsplit/train/train_adapter_parallel.py
@@ -123,6 +123,7 @@ class Args:
     do_lowercase: bool = False
     do_remove_punct: bool = False
     eval_pairwise: bool = False
+    skip_eval_loss: bool = False
 
 
 def main(
@@ -308,6 +309,7 @@ def compute_metrics(trainer):
                 add_lang_ids=False
             ),
             logging_prefix=f"{dataset_name}/{lang}/",
+            skip_eval_loss=args.skip_eval_loss,
         )
         if callbacks:
            trainer.add_callback(callbacks)
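
--
Usage note (reviewer sketch, not part of the patch): the constructor call below mirrors the train_adapter_parallel.py hunk above. Only skip_eval_loss and logging_prefix come from this diff; the model, training arguments, and dataset objects are hypothetical stand-ins, since the full AdapterTrainer signature is not shown here.

    # Hypothetical wiring; only skip_eval_loss and logging_prefix are from this patch.
    from wtpsplit.train.adaptertrainer import AdapterTrainer

    trainer = AdapterTrainer(
        model,                                   # assumed: adapter-equipped backbone
        training_args,                           # assumed: transformers TrainingArguments
        train_dataset=train_dataset,             # assumed dataset objects
        eval_dataset=eval_dataset,
        logging_prefix=f"{dataset_name}/{lang}/",  # per-dataset/language metric prefix, as in the diff
        skip_eval_loss=True,                     # new flag: evaluation_loop() skips the whole
                                                 # prediction/loss pass and returns empty
                                                 # containers (all_* = None, num_samples = 0)
    )
    trainer.train()

Because the flag defaults to False in both AdapterTrainer and Args, existing runs are unaffected. Metrics themselves are produced in compute_metrics (see the "removed since done in compute_metrics" comment above), so evaluation remains useful with the loss pass skipped. If the Args dataclass is parsed with HfArgumentParser, as the surrounding boolean options suggest, the new behavior should also be reachable from the command line via --skip_eval_loss.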