Skip to content

Commit

Permalink
fix: adapt finetune + continue_train to same changes as in train
Browse files Browse the repository at this point in the history
  • Loading branch information
irinaespejo committed Apr 16, 2024
1 parent 68626f9 commit 4a3f439
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ def main(
)

# Write config file
command_and_args = train_cmd.save_to_config_cmd(config_file)
run_command(command_and_args)
train_cmd.save_to_config_cmd(config_file)

# Actual training config file
command_and_args = train_cmd.execute_from_config_cmd(config_file)
Expand Down
5 changes: 2 additions & 3 deletions src/rxn/onmt_models/scripts/rxn_onmt_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def main(
dropout=dropout,
keep_checkpoint=keep_checkpoint,
learning_rate=learning_rate,
rnn_size=rnn_size,
hidden_size=rnn_size,
save_model=model_files.model_prefix,
seed=seed,
train_from=train_from,
Expand All @@ -125,8 +125,7 @@ def main(
)

# Write config file
command_and_args = train_cmd.save_to_config_cmd(config_file)
run_command(command_and_args)
train_cmd.save_to_config_cmd(config_file)

# Actual training config file
command_and_args = train_cmd.execute_from_config_cmd(config_file)
Expand Down
3 changes: 2 additions & 1 deletion src/rxn/onmt_models/scripts/rxn_onmt_train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from pathlib import Path
from typing import Tuple

import click
Expand All @@ -15,7 +16,7 @@
logger.addHandler(logging.NullHandler())


def get_src_tgt_vocab(data):
def get_src_tgt_vocab(data: Path) -> Tuple[Path, Path]:
src_vocab = data.parent / (data.name + ".vocab.src")
tgt_vocab = data.parent / (data.name + ".vocab.tgt")
return src_vocab, tgt_vocab
Expand Down

0 comments on commit 4a3f439

Please sign in to comment.