diff --git a/src/rxn/onmt_utils/train_command.py b/src/rxn/onmt_utils/train_command.py
index b316040..1511ba7 100644
--- a/src/rxn/onmt_utils/train_command.py
+++ b/src/rxn/onmt_utils/train_command.py
@@ -1,8 +1,14 @@
+import logging
 from enum import Flag
-from typing import Any, List, Tuple
+from typing import Any, List, Optional, Tuple
 
 from rxn.utilities.files import PathLike
 
+from .model_introspection import get_model_rnn_size
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
 
 class RxnCommand(Flag):
     """
@@ -246,7 +252,6 @@ def finetune(
         data: PathLike,
         dropout: float,
         learning_rate: float,
-        rnn_size: int,
         save_model: PathLike,
         seed: int,
         train_from: PathLike,
@@ -256,7 +261,15 @@
         data_weights: Tuple[int, ...],
         report_every: int,
         save_checkpoint_steps: int,
+        rnn_size: Optional[int] = None,
     ) -> "OnmtTrainCommand":
+        if rnn_size is None:
+            # In principle, the rnn_size should not be needed for finetuning. However,
+            # when resetting the decay algorithm for the learning rate, this value
+            # is necessary, but it is not read from the model checkpoint (OpenNMT bug).
+            rnn_size = get_model_rnn_size(train_from)
+            logger.info(f"Loaded the value of rnn_size from the model: {rnn_size}.")
+
         return cls(
             command_type=RxnCommand.F,
             no_gpu=no_gpu,
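
For context, a minimal sketch of what the newly imported `get_model_rnn_size` helper could look like. Its implementation is not part of this diff, so the function body below is an assumption, not the actual code from `.model_introspection`; it assumes the OpenNMT-py checkpoint layout, i.e. a torch-serialized dict whose `"opt"` entry holds the training options as an `argparse.Namespace` with an `rnn_size` attribute.

```python
from argparse import Namespace

import torch

from rxn.utilities.files import PathLike


def get_model_rnn_size(model_path: PathLike) -> int:
    """Read the rnn_size stored in an OpenNMT-py model checkpoint.

    Hypothetical sketch: assumes the checkpoint is a dict with the training
    options saved under the "opt" key, as done by OpenNMT-py.
    """
    # Load on CPU so that no GPU is required just to inspect the options.
    checkpoint = torch.load(model_path, map_location="cpu")
    opt: Namespace = checkpoint["opt"]
    return opt.rnn_size
```

With a helper of this kind available, the `rnn_size` argument can become optional: callers that omit it get the value introspected from the `train_from` checkpoint, which works around OpenNMT not reading it from the checkpoint itself when the learning-rate decay is reset.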