diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py
index 24b77f162aa..5724e76b4df 100644
--- a/parlai/scripts/train_model.py
+++ b/parlai/scripts/train_model.py
@@ -28,6 +28,7 @@
 import torch
 import json
 import os
+from pathlib import Path
 import signal
 from typing import Tuple
 
@@ -58,7 +59,7 @@
     num_workers,
 )
 from parlai.utils.io import PathManager
-from parlai.utils.misc import Timer, nice_report
+from parlai.utils.misc import Timer, nice_report, ordinal
 from parlai.utils.world_logging import WorldLogger
 
 
@@ -134,6 +135,13 @@ def setup_args(parser=None) -> ParlaiParser:
         default=-1,
         help='End training after n model updates',
     )
+    train.add_argument(
+        '-stopk',
+        '--save-top-k-checkpoints',
+        type=int,
+        default=1,
+        help='Save and keep the k checkpoints with the best validation metric',
+    )
     train.add_argument('-ltim', '--log-every-n-secs', type=float, default=-1)
     train.add_argument(
         '-lstep',
@@ -410,6 +418,7 @@ def __init__(self, opt):
         self.save_every_n_secs = _num_else_inf(
             opt, 'save_every_n_secs', distributed_warn=True
         )
+        self.save_top_k = _num_else_inf(opt, 'save_top_k_checkpoints')
 
         # smart defaults for --validation-metric-mode
         if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
@@ -428,6 +437,7 @@ def __init__(self, opt):
         self.final_test_report = {}
         self.final_extra_valid_report = {}
         self.best_valid = None
+        self.best_k_models = []  # [checkpoint path, metric] pairs, ordered best first
 
         self.impatience = 0
         self.saved = False
@@ -454,6 +464,7 @@ def __init__(self, opt):
                     'total_epochs', 0.0
                 )
                 self.train_reports = obj.get('train_reports', [])
+                self.best_k_models = obj.get('best_k_models', [])
                 if 'best_valid' in obj:
                     self.best_valid = obj['best_valid']
                 else:
@@ -525,6 +536,7 @@ def _save_train_stats(self, suffix=None):
                     'train_reports': self.train_reports,
                     'valid_reports': self.valid_reports,
                     'best_valid': self.best_valid,
+                    'best_k_models': self.best_k_models,
                     'impatience': self.impatience,
                     'final_valid_report': dict_report(self.final_valid_report),
                     'final_test_report': dict_report(self.final_test_report),
@@ -587,8 +599,10 @@ def validate(self):
 
         # check if this is the best validation so far
         if (
-            self.best_valid is None
-            or self.valid_optim * new_valid > self.valid_optim * self.best_valid
+            (
+                self.best_valid is None
+                or self.valid_optim * new_valid > self.valid_optim * self.best_valid
+            ) and self.save_top_k == 1
         ):
             logging.success(
                 'new best {}: {:.4g}{}'.format(
@@ -614,13 +628,55 @@ def validate(self):
             ):
                 logging.info('task solved! stopping.')
                 return True
+        elif (
+            self.save_top_k > 1
+            and self.opt.get('model_file')
+            and (
+                len(self.best_k_models) < self.save_top_k
+                # or the new metric beats the kth (worst) checkpoint saved so far
+                or self.valid_optim * new_valid
+                > self.valid_optim * self.best_k_models[-1][1]
+            )
+        ):
+            self.impatience = 0
+            # rank = number of saved checkpoints still better than the new metric
+            model_rank = sum(
+                self.valid_optim * saved[1] > self.valid_optim * new_valid
+                for saved in self.best_k_models
+            )
+            model_suffix = '_' + ordinal(model_rank + 1) + '.' + str(self._train_steps)
+            self.best_k_models.insert(
+                model_rank, [self.opt['model_file'] + model_suffix, new_valid]
+            )
+            self.save_model(model_suffix)  # saved as "<model_file>_<nth>.<train_steps>"
+            self.saved = True
+            self._modify_next_rank_checkpoints(model_rank)
+            if (
+                opt['validation_metric_mode'] == 'max'
+                and self.best_k_models[-1][1] >= opt['validation_cutoff']
+            ) or (
+                opt['validation_metric_mode'] == 'min'
+                and self.best_k_models[-1][1] <= opt['validation_cutoff']
+            ):
+                logging.info('task solved! stopping.')
+                return True
         else:
             self.impatience += 1
-            logging.report(
-                'did not beat best {}: {} impatience: {}'.format(
-                    opt['validation_metric'], round(self.best_valid, 4), self.impatience
+            if self.save_top_k == 1:
+                logging.report(
+                    'did not beat best {}: {} impatience: {}'.format(
+                        opt['validation_metric'], round(self.best_valid, 4), self.impatience
+                    )
+                )
+            else:
+                logging.report(
+                    'did not beat {} best {}: {} impatience: {}'.format(
+                        ordinal(self.save_top_k),
+                        opt['validation_metric'],
+                        round(self.best_k_models[-1][1], 4),
+                        self.impatience,
+                    )
                 )
-            )
         self.validate_time.reset()
 
         # saving
@@ -636,6 +692,24 @@ def validate(self):
             logging.info('ran out of patience! stopping training.')
             return True
         return False
+
+    def _modify_next_rank_checkpoints(self, model_rank):
+        if len(self.best_k_models) > self.save_top_k:
+            # drop the checkpoint that fell out of the top k and its files on disk
+            last_path = Path(self.best_k_models[-1][0])
+            for file in last_path.parent.glob(last_path.name + '*'):
+                file.unlink()
+            del self.best_k_models[-1]
+        # checkpoints pushed down a rank get renamed to their new ordinal
+        for ind in range(model_rank + 1, len(self.best_k_models)):
+            prev_model_path = Path(self.best_k_models[ind][0])
+            model_train_steps = prev_model_path.suffix[1:]
+            new_model_path = Path(
+                self.opt['model_file'] + '_' + ordinal(ind + 1) + '.' + model_train_steps
+            )
+            for file in prev_model_path.parent.glob(prev_model_path.name + '*'):
+                file.rename(str(new_model_path) + ''.join(file.suffixes[1:]))
+            self.best_k_models[ind][0] = str(new_model_path)
 
     def _run_single_eval(
         self, opt, valid_world, max_exs, datatype, is_multitask, task, index
diff --git a/parlai/utils/misc.py b/parlai/utils/misc.py
index 66435c69693..6a2f294cffa 100644
--- a/parlai/utils/misc.py
+++ b/parlai/utils/misc.py
@@ -323,6 +323,17 @@ def _report_sort_key(report_key: str) -> Tuple[str, str]:
     sub_key = '/'.join(fields)
     return (sub_key or 'all', main_key)
 
+
+def ordinal(n: int) -> str:
+    """
+    Convert an integer to its ordinal string, e.g. 1 -> '1st', 2 -> '2nd', 11 -> '11th'.
+    """
+    if 11 <= (n % 100) <= 13:
+        suffix = 'th'
+    else:
+        suffix = ['th', 'st', 'nd', 'rd', 'th'][min(n % 10, 4)]
+    return str(n) + suffix
+
 
 def float_formatter(f: Union[float, int]) -> str:
     """
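
Editorial note, not part of the patch: the sketch below replays the top-k bookkeeping that validate() and _modify_next_rank_checkpoints() implement, but on an in-memory list only, so it can be run standalone. The ordinal() helper mirrors the one added to parlai/utils/misc.py above; record_checkpoint, the model_file path, the perplexity values, and the step counts are invented for illustration, and saving, deleting, and renaming checkpoint files are only noted in comments rather than performed.

# In-memory sketch of the top-k checkpoint bookkeeping; no files are touched.

def ordinal(n: int) -> str:
    # same logic as the ordinal() helper added to parlai/utils/misc.py
    if 11 <= (n % 100) <= 13:
        suffix = 'th'
    else:
        suffix = ['th', 'st', 'nd', 'rd', 'th'][min(n % 10, 4)]
    return str(n) + suffix


def record_checkpoint(
    best_k_models, model_file, new_valid, train_steps, k, valid_optim
):
    # Skip metrics that do not make the top k (mirrors the elif guard in validate()).
    if len(best_k_models) >= k and not (
        valid_optim * new_valid > valid_optim * best_k_models[-1][1]
    ):
        return
    # Rank = number of saved checkpoints still better than the new metric.
    rank = sum(valid_optim * m[1] > valid_optim * new_valid for m in best_k_models)
    name = model_file + '_' + ordinal(rank + 1) + '.' + str(train_steps)
    # The real code also calls save_model(suffix) here to write the checkpoint.
    best_k_models.insert(rank, [name, new_valid])
    if len(best_k_models) > k:
        del best_k_models[-1]  # the real code also unlinks this checkpoint's files
    # Checkpoints pushed down a rank are relabelled with their new ordinal.
    for ind in range(rank + 1, len(best_k_models)):
        steps = best_k_models[ind][0].rsplit('.', 1)[1]
        best_k_models[ind][0] = model_file + '_' + ordinal(ind + 1) + '.' + steps


best_k = []
for steps, ppl in [(100, 14.2), (200, 12.8), (300, 13.5), (400, 11.9)]:
    # valid_optim=-1 because lower perplexity is better (validation_metric_mode 'min')
    record_checkpoint(best_k, '/tmp/model', ppl, steps, k=3, valid_optim=-1)
print(best_k)
# [['/tmp/model_1st.400', 11.9], ['/tmp/model_2nd.200', 12.8], ['/tmp/model_3rd.300', 13.5]]

With the patch applied, this behaviour is driven by the new flag added above, e.g. passing --save-top-k-checkpoints 3 (or -stopk 3) to the train_model script. Note that the toy rsplit('.', 1), like the patch's use of Path.suffix, assumes the --model-file path itself contains no '.'.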