Skip to content

Commit

Permalink
optionally, skip eval loss
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed Mar 6, 2024
1 parent 4ac568c commit 5c89af1
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 162 deletions.
329 changes: 167 additions & 162 deletions wtpsplit/train/adaptertrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
logging_prefix="",
skip_eval_loss: bool = False,
):
super().__init__(
model,
Expand All @@ -134,6 +135,7 @@ def __init__(
)

self.logging_prefix = logging_prefix
self.skip_eval_loss = skip_eval_loss

if adapter_names is not None:
self.model.backbone.set_active_adapters(adapter_names)
Expand Down Expand Up @@ -803,182 +805,185 @@ def evaluation_loop(
Works both with or without labels.
"""
args = self.args

if not self.skip_eval_loss:
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

# if eval is called w/o train init deepspeed here
if args.deepspeed and not self.deepspeed:
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
# from the checkpoint eventually
deepspeed_engine, _, _ = deepspeed_init(
self, num_training_steps=0, resume_from_checkpoint=None, inference=True
)
self.model = deepspeed_engine.module
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine

prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
model = self._wrap_model(self.model, training=False, dataloader=dataloader)

# if eval is called w/o train init deepspeed here
if args.deepspeed and not self.deepspeed:
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
# from the checkpoint eventually
deepspeed_engine, _, _ = deepspeed_init(
self, num_training_steps=0, resume_from_checkpoint=None, inference=True
)
self.model = deepspeed_engine.module
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
if args.fp16_full_eval:
model = model.to(dtype=torch.float16, device=args.device)
elif args.bf16_full_eval:
model = model.to(dtype=torch.bfloat16, device=args.device)

model = self._wrap_model(self.model, training=False, dataloader=dataloader)
batch_size = self.args.eval_batch_size

# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
if args.fp16_full_eval:
model = model.to(dtype=torch.float16, device=args.device)
elif args.bf16_full_eval:
model = model.to(dtype=torch.bfloat16, device=args.device)
logger.info(f"***** Running {description} *****")
if has_length(dataloader):
logger.warning(f" Num examples = {self.num_examples(dataloader)}")
else:
logger.info(" Num examples: Unknown")
logger.info(f" Batch size = {batch_size}")

batch_size = self.args.eval_batch_size
model.eval()

logger.info(f"***** Running {description} *****")
if has_length(dataloader):
logger.warning(f" Num examples = {self.num_examples(dataloader)}")
else:
logger.info(" Num examples: Unknown")
logger.info(f" Batch size = {batch_size}")

model.eval()

self.callback_handler.eval_dataloader = dataloader
# Do this before wrapping.
eval_dataset = getattr(dataloader, "dataset", None)

# MODIFIED: not necessary.
# if is_torch_tpu_available():
# dataloader = pl.MpDeviceLoader(dataloader, args.device) # .per_device_loader(args.device)

if args.past_index >= 0:
self._past = None

# Initialize containers
# losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
losses_host = None
preds_host = None
labels_host = None
inputs_host = None

# losses/preds/labels on CPU (final containers)
all_losses = None
all_preds = None
all_labels = None
all_inputs = None
# Will be useful when we have an iterable dataset so don't know its length.

observed_num_examples = 0
# Main evaluation loop
for step, inputs in enumerate(dataloader):
# Update the observed num examples
observed_batch_size = find_batch_size(inputs)
if observed_batch_size is not None:
observed_num_examples += observed_batch_size
# For batch samplers, batch_size is not known by the dataloader in advance.
if batch_size is None:
batch_size = observed_batch_size

# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
self.callback_handler.eval_dataloader = dataloader
# Do this before wrapping.
eval_dataset = getattr(dataloader, "dataset", None)

# MODIFIED: not necessary.
# if is_torch_tpu_available():
# xm.mark_step()
# dataloader = pl.MpDeviceLoader(dataloader, args.device) # .per_device_loader(args.device)

# Update containers on host
if loss is not None:
# MODIFIED: do not gather across devices. (different loss on each device)
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if labels is not None:
labels = self._pad_across_processes(labels)
# MODIFIED: do not gather across devices.
# labels = self._nested_gather(labels)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
if inputs_decode is not None:
inputs_decode = self._pad_across_processes(inputs_decode)
# MODIFIED: do not gather across devices.
# inputs_decode = self._nested_gather(inputs_decode)
inputs_host = (
inputs_decode
if inputs_host is None
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
)
if logits is not None:
logits = self._pad_across_processes(logits)
# logits = self._nested_gather(logits)
if self.preprocess_logits_for_metrics is not None:
logits = self.preprocess_logits_for_metrics(logits, labels)
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
if preds_host is not None:
logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
if args.past_index >= 0:
self._past = None

# Initialize containers
# losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
losses_host = None
preds_host = None
labels_host = None
inputs_host = None

# losses/preds/labels on CPU (final containers)
all_losses = None
all_preds = None
all_labels = None
all_inputs = None
# Will be useful when we have an iterable dataset so don't know its length.

observed_num_examples = 0
# Main evaluation loop
for step, inputs in enumerate(dataloader):
# Update the observed num examples
observed_batch_size = find_batch_size(inputs)
if observed_batch_size is not None:
observed_num_examples += observed_batch_size
# For batch samplers, batch_size is not known by the dataloader in advance.
if batch_size is None:
batch_size = observed_batch_size

# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

# MODIFIED: not necessary.
# if is_torch_tpu_available():
# xm.mark_step()

# Update containers on host
if loss is not None:
# MODIFIED: do not gather across devices. (different loss on each device)
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if labels is not None:
labels = self._pad_across_processes(labels)
# MODIFIED: do not gather across devices.
# labels = self._nested_gather(labels)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
if inputs_decode is not None:
inputs_decode = self._pad_across_processes(inputs_decode)
# MODIFIED: do not gather across devices.
# inputs_decode = self._nested_gather(inputs_decode)
inputs_host = (
inputs_decode
if all_inputs is None
else nested_concat(all_inputs, inputs_decode, padding_index=-100)
if inputs_host is None
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
)
if logits is not None:
logits = self._pad_across_processes(logits)
# logits = self._nested_gather(logits)
if self.preprocess_logits_for_metrics is not None:
logits = self.preprocess_logits_for_metrics(logits, labels)
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
if preds_host is not None:
logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode
if all_inputs is None
else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None:
labels = nested_numpify(labels_host)
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

# Set back to None to begin a new accumulation
losses_host, preds_host, inputs_host, labels_host = (
None,
None,
None,
None,
)
if labels_host is not None:
labels = nested_numpify(labels_host)
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

# Set back to None to begin a new accumulation
losses_host, preds_host, inputs_host, labels_host = (
None,
None,
None,
None,
)

if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")

# Gather all remaining tensors and put them back on the CPU
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
if preds_host is not None:
logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None:
labels = nested_numpify(labels_host)
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

# Number of samples
if has_length(eval_dataset):
num_samples = len(eval_dataset)
# The instance check is weird and does not actually check for the type, but whether the dataset has the right
# methods. Therefore we need to make sure it also has the attribute.
elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
num_samples = eval_dataset.num_examples
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")

# Gather all remaining tensors and put them back on the CPU
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
if preds_host is not None:
logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None:
labels = nested_numpify(labels_host)
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

# Number of samples
if has_length(eval_dataset):
num_samples = len(eval_dataset)
# The instance check is weird and does not actually check for the type, but whether the dataset has the right
# methods. Therefore we need to make sure it also has the attribute.
elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
num_samples = eval_dataset.num_examples
else:
if has_length(dataloader):
num_samples = self.num_examples(dataloader)
else: # both len(dataloader.dataset) and len(dataloader) fail
num_samples = observed_num_examples

# Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
# samplers has been rounded to a multiple of batch_size, so we truncate.
if all_losses is not None:
all_losses = all_losses[:num_samples]
if all_preds is not None:
all_preds = nested_truncate(all_preds, num_samples)
if all_labels is not None:
all_labels = nested_truncate(all_labels, num_samples)
if all_inputs is not None:
all_inputs = nested_truncate(all_inputs, num_samples)
else:
if has_length(dataloader):
num_samples = self.num_examples(dataloader)
else: # both len(dataloader.dataset) and len(dataloader) fail
num_samples = observed_num_examples

# Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
# samplers has been rounded to a multiple of batch_size, so we truncate.
if all_losses is not None:
all_losses = all_losses[:num_samples]
if all_preds is not None:
all_preds = nested_truncate(all_preds, num_samples)
if all_labels is not None:
all_labels = nested_truncate(all_labels, num_samples)
if all_inputs is not None:
all_inputs = nested_truncate(all_inputs, num_samples)
all_losses, all_preds, all_labels, all_inputs, num_samples = None, None, None, None, 0

# Metrics!
# MODIFIED: removed since done in compute_metrics
Expand Down
2 changes: 2 additions & 0 deletions wtpsplit/train/train_adapter_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class Args:
do_lowercase: bool = False
do_remove_punct: bool = False
eval_pairwise: bool = False
skip_eval_loss: bool = False


def main(
Expand Down Expand Up @@ -308,6 +309,7 @@ def compute_metrics(trainer):
add_lang_ids=False
),
logging_prefix=f"{dataset_name}/{lang}/",
skip_eval_loss=args.skip_eval_loss
)
if callbacks:
trainer.add_callback(callbacks)
Expand Down

0 comments on commit 5c89af1

Please sign in to comment.