Skip to content

Commit

Permalink
fix: pass config file to onmt_train
Browse files Browse the repository at this point in the history
  • Loading branch information
irinaespejo committed Apr 16, 2024
1 parent 12554b4 commit fc5dc17
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions src/rxn/onmt_models/scripts/rxn_onmt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@
from typing import Tuple

import click
from rxn.onmt_utils import __version__ as onmt_utils_version
from rxn.onmt_utils.train_command import OnmtTrainCommand
from rxn.utilities.logging import setup_console_and_file_logger

from rxn.onmt_models import __version__ as onmt_models_version
from rxn.onmt_models import defaults
from rxn.onmt_models.training_files import ModelFiles, OnmtPreprocessedFiles
from rxn.onmt_models.utils import log_file_name_from_time, run_command
from rxn.onmt_utils import __version__ as onmt_utils_version
from rxn.onmt_utils.train_command import OnmtTrainCommand
from rxn.utilities.logging import setup_console_and_file_logger

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def get_src_tgt_vocab(data):
    """Derive the source and target vocabulary file paths from a data prefix.

    Both vocabulary files are placed in the same directory as ``data`` and
    are named after its final path component with ``.vocab.src`` and
    ``.vocab.tgt`` appended, respectively.

    Args:
        data: Path prefix. A ``str`` is converted to a ``pathlib.Path``
            first; any other object is used as-is and must expose
            ``.parent`` and ``.name`` the way ``pathlib`` paths do.

    Returns:
        A tuple ``(src_vocab, tgt_vocab)`` with the two vocabulary paths.
    """
    # Imported locally so this helper does not depend on a module-level
    # ``pathlib`` import being present in the file.
    from pathlib import Path

    # Plain strings have no ``.parent`` / ``.name``; promote them to a Path.
    prefix = Path(data) if isinstance(data, str) else data

    # Append to the full name instead of using ``with_suffix()``, which would
    # replace anything after the last dot in the prefix (e.g. "run.v1").
    src_vocab = prefix.parent / (prefix.name + ".vocab.src")
    tgt_vocab = prefix.parent / (prefix.name + ".vocab.tgt")
    return src_vocab, tgt_vocab


@click.command(context_settings=dict(show_default=True))
@click.option("--batch_size", default=defaults.BATCH_SIZE)
@click.option(
Expand Down Expand Up @@ -88,15 +93,20 @@ def main(

config_file = model_files.next_config_file()

src_vocab, tgt_vocab = get_src_tgt_vocab(data=onmt_preprocessed_files.preprocess_prefix)

# Init
train_cmd = OnmtTrainCommand.train(
batch_size=batch_size,
data=onmt_preprocessed_files.preprocess_prefix,
src_vocab=src_vocab,
tgt_vocab=tgt_vocab,
dropout=dropout,
heads=heads,
keep_checkpoint=keep_checkpoint,
layers=layers,
learning_rate=learning_rate,
rnn_size=rnn_size,
hidden_size=rnn_size,
save_model=model_files.model_prefix,
seed=seed,
train_steps=train_num_steps,
Expand All @@ -108,10 +118,14 @@ def main(
)

# Write config file
command_and_args = train_cmd.save_to_config_cmd(config_file)
run_command(command_and_args)
import ipdb
ipdb.set_trace()

train_cmd.save_to_config_cmd(config_file)
#import ipdb
#ipdb.set_trace()

# Actual training config file
# Actual training config file
command_and_args = train_cmd.execute_from_config_cmd(config_file)
run_command(command_and_args)

Expand Down

0 comments on commit fc5dc17

Please sign in to comment.