Skip to content

Commit

Permalink
chore: upgrade dependecy commit to rxn-onmt-utils + black, isort...
Browse files Browse the repository at this point in the history
  • Loading branch information
irinaespejo committed Apr 16, 2024
1 parent 61fc311 commit 68626f9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ install_requires =
rxn-chem-utils>=1.1.4
rxn-reaction-preprocessing>=2.0.2
rxn-utils>=1.1.9
rxn-onmt-utils @ git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@c8591b42bf04400e6c6c5847d5edb3326481ba34 #rxn-onmt-utils without rxn-opennmt-py depedency
rxn-onmt-utils @ git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@d568d0fbeb11dace1de289d0d95fe1e00b22069d #rxn-onmt-utils without rxn-opennmt-py depedency
OpenNMT-py>=3.5.1 # official onmt

[options.packages.find]
Expand Down
13 changes: 8 additions & 5 deletions src/rxn/onmt_models/scripts/rxn_onmt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from typing import Tuple

import click
from rxn.onmt_utils import __version__ as onmt_utils_version
from rxn.onmt_utils.train_command import OnmtTrainCommand
from rxn.utilities.logging import setup_console_and_file_logger

from rxn.onmt_models import __version__ as onmt_models_version
from rxn.onmt_models import defaults
from rxn.onmt_models.training_files import ModelFiles, OnmtPreprocessedFiles
from rxn.onmt_models.utils import log_file_name_from_time, run_command
from rxn.onmt_utils import __version__ as onmt_utils_version
from rxn.onmt_utils.train_command import OnmtTrainCommand
from rxn.utilities.logging import setup_console_and_file_logger

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down Expand Up @@ -93,7 +94,9 @@ def main(

config_file = model_files.next_config_file()

src_vocab, tgt_vocab = get_src_tgt_vocab(data=onmt_preprocessed_files.preprocess_prefix)
src_vocab, tgt_vocab = get_src_tgt_vocab(
data=onmt_preprocessed_files.preprocess_prefix
)

# Init
train_cmd = OnmtTrainCommand.train(
Expand All @@ -120,7 +123,7 @@ def main(
# Write config file
train_cmd.save_to_config_cmd(config_file)

# Actual training config file
# Actual training config file
command_and_args = train_cmd.execute_from_config_cmd(config_file)
run_command(command_and_args)

Expand Down

0 comments on commit 68626f9

Please sign in to comment.