diff --git a/src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py b/src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py index 1004f92..463dd03 100644 --- a/src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py +++ b/src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py @@ -114,8 +114,7 @@ def main( ) # Write config file - command_and_args = train_cmd.save_to_config_cmd(config_file) - run_command(command_and_args) + train_cmd.save_to_config_cmd(config_file) # Actual training config file command_and_args = train_cmd.execute_from_config_cmd(config_file) diff --git a/src/rxn/onmt_models/scripts/rxn_onmt_finetune.py b/src/rxn/onmt_models/scripts/rxn_onmt_finetune.py index c30fed8..4e6e69b 100644 --- a/src/rxn/onmt_models/scripts/rxn_onmt_finetune.py +++ b/src/rxn/onmt_models/scripts/rxn_onmt_finetune.py @@ -112,7 +112,7 @@ def main( dropout=dropout, keep_checkpoint=keep_checkpoint, learning_rate=learning_rate, - rnn_size=rnn_size, + hidden_size=rnn_size, save_model=model_files.model_prefix, seed=seed, train_from=train_from, @@ -125,8 +125,7 @@ def main( ) # Write config file - command_and_args = train_cmd.save_to_config_cmd(config_file) - run_command(command_and_args) + train_cmd.save_to_config_cmd(config_file) # Actual training config file command_and_args = train_cmd.execute_from_config_cmd(config_file) diff --git a/src/rxn/onmt_models/scripts/rxn_onmt_train.py b/src/rxn/onmt_models/scripts/rxn_onmt_train.py index fa33178..497fdf2 100644 --- a/src/rxn/onmt_models/scripts/rxn_onmt_train.py +++ b/src/rxn/onmt_models/scripts/rxn_onmt_train.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path from typing import Tuple import click @@ -15,7 +16,7 @@ logger.addHandler(logging.NullHandler()) -def get_src_tgt_vocab(data): +def get_src_tgt_vocab(data: Path) -> Tuple[Path, Path]: src_vocab = data.parent / (data.name + ".vocab.src") tgt_vocab = data.parent / (data.name + ".vocab.tgt") return src_vocab, tgt_vocab