Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Now multiple checkpoints will be saved after using -stopk option in train_model.py #4978

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
78 changes: 71 additions & 7 deletions parlai/scripts/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import torch
import json
import os
from pathlib import Path
import signal
from typing import Tuple

Expand Down Expand Up @@ -58,7 +59,7 @@
num_workers,
)
from parlai.utils.io import PathManager
from parlai.utils.misc import Timer, nice_report
from parlai.utils.misc import Timer, nice_report, ordinal
from parlai.utils.world_logging import WorldLogger


Expand Down Expand Up @@ -134,6 +135,13 @@ def setup_args(parser=None) -> ParlaiParser:
default=-1,
help='End training after n model updates',
)
train.add_argument(
'-stopk',
'--save-top-k-checkpoints',
type=int,
default=1,
help='Save and keep checkpoints with top k validation metric',
)
train.add_argument('-ltim', '--log-every-n-secs', type=float, default=-1)
train.add_argument(
'-lstep',
Expand Down Expand Up @@ -409,6 +417,7 @@ def __init__(self, opt):
self.save_every_n_secs = _num_else_inf(
opt, 'save_every_n_secs', distributed_warn=True
)
self.save_top_k = _num_else_inf(opt, 'save_top_k_checkpoints')

# smart defaults for --validation-metric-mode
if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
Expand All @@ -427,6 +436,7 @@ def __init__(self, opt):
self.final_test_report = {}
self.final_extra_valid_report = {}
self.best_valid = None
self.best_k_models = [] #every element is a (checkpoint address, validation metric) tuple

self.impatience = 0
self.saved = False
Expand All @@ -453,6 +463,7 @@ def __init__(self, opt):
'total_epochs', 0.0
)
self.train_reports = obj.get('train_reports', [])
self.best_k_models = obj.get('best_k_models', {})
if 'best_valid' in obj:
self.best_valid = obj['best_valid']
else:
Expand Down Expand Up @@ -522,6 +533,7 @@ def _save_train_stats(self, suffix=None):
'train_reports': self.train_reports,
'valid_reports': self.valid_reports,
'best_valid': self.best_valid,
'best_k_models': self.best_k_models,
'impatience': self.impatience,
'final_valid_report': dict_report(self.final_valid_report),
'final_test_report': dict_report(self.final_test_report),
Expand Down Expand Up @@ -579,8 +591,10 @@ def validate(self):

# check if this is the best validation so far
if (
self.best_valid is None
or self.valid_optim * new_valid > self.valid_optim * self.best_valid
(
self.best_valid is None
or self.valid_optim * new_valid > self.valid_optim * self.best_valid
) and self.save_top_k == 1
):
logging.success(
'new best {}: {:.4g}{}'.format(
Expand All @@ -606,13 +620,45 @@ def validate(self):
):
logging.info('task solved! stopping.')
return True
elif (
self.save_top_k > 1
and self.opt.get('model_file')
and (
len(self.best_k_models) < self.save_top_k
or self.valid_optim * new_valid > self.valid_optim * self.best_k_models[-1][1]
# if new validation metric is better than kth saved model metric
)
):
self.impatience = 0
model_rank = sum(new_valid < saved_model_prop[1] for saved_model_prop in self.best_k_models)
model_suffix = '_' + ordinal(model_rank+1) + '.' + str(self._train_steps)
self.best_k_models.insert(model_rank, [self.opt['model_file']+model_suffix, new_valid])
self.save_model(model_suffix) # Save model as "model_nth.<number_of_train_steps>"
self.saved = True
self._modify_next_rank_checkpoints(model_rank)
if (
opt['validation_metric_mode'] == 'max'
and self.best_k_models[-1][1] >= opt['validation_cutoff']
) or (
opt['validation_metric_mode'] == 'min'
and self.best_k_models[-1][1] <= opt['validation_cutoff']
):
logging.info('task solved! stopping.')
return True
Comment on lines +647 to +655
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this check requires that we look at self.best_k_models[0], right? since we're looking at the best metric

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My idea here was that we want the last saved model's metric to be better than validation_cutoff so I checked the last one with self.best_k_models[-1]

else:
self.impatience += 1
logging.report(
'did not beat best {}: {} impatience: {}'.format(
opt['validation_metric'], round(self.best_valid, 4), self.impatience
if self.save_top_k == 1:
logging.report(
'did not beat best {}: {} impatience: {}'.format(
opt['validation_metric'], round(self.best_valid, 4), self.impatience
)
)
else:
logging.report(
'did not beat {} model\'s {}: {} impatience: {}'.format(
ordinal(self.save_top_k), opt['validation_metric'], round(self.best_k_models[-1][1], 4), self.impatience
)
)
)
self.validate_time.reset()

# saving
Expand All @@ -628,6 +674,24 @@ def validate(self):
logging.info('ran out of patience! stopping training.')
return True
return False



def _modify_next_rank_checkpoints(self, model_rank):
if len(self.best_k_models) > self.save_top_k:
#remove last best model and its files from disk and best_k_models list to make space for new model
last_path = Path(self.best_k_models[-1][0])
for file in last_path.parent.glob(last_path.name + '*'):
file.unlink()
del self.best_k_models[-1]
for ind in range(model_rank+1, len(self.best_k_models)):
prev_model_path = Path(self.best_k_models[ind][0])
model_train_steps = prev_model_path.suffix[1:]
new_model_path = Path(self.opt['model_file'] + '_' + ordinal(ind+1) + '.' + model_train_steps)
for file in prev_model_path.parent.glob(prev_model_path.name + '*'):
file.rename(str(new_model_path) + ''.join(file.suffixes[1:]))
self.best_k_models[ind][0] = str(new_model_path)


def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):

Expand Down
10 changes: 10 additions & 0 deletions parlai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,16 @@ def _report_sort_key(report_key: str) -> Tuple[str, str]:
sub_key = '/'.join(fields)
return (sub_key or 'all', main_key)

def ordinal(n: int):
"""
Convert a number to its ordinal counterpart
"""
if 11 <= (n % 100) <= 13:
suffix = 'th'
else:
suffix = ['th', 'st', 'nd', 'rd', 'th'][min(n % 10, 4)]
return str(n) + suffix


def float_formatter(f: Union[float, int]) -> str:
"""
Expand Down