From 68626f90f1432f7103e0cfe11fbbb717c533d850 Mon Sep 17 00:00:00 2001 From: irinaespejo Date: Tue, 16 Apr 2024 15:03:20 +0200 Subject: [PATCH] chore: upgrade dependecy commit to rxn-onmt-utils + black, isort... --- setup.cfg | 2 +- src/rxn/onmt_models/scripts/rxn_onmt_train.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/setup.cfg b/setup.cfg index 5d4c384..1e72725 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ install_requires = rxn-chem-utils>=1.1.4 rxn-reaction-preprocessing>=2.0.2 rxn-utils>=1.1.9 - rxn-onmt-utils @ git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@c8591b42bf04400e6c6c5847d5edb3326481ba34 #rxn-onmt-utils without rxn-opennmt-py depedency + rxn-onmt-utils @ git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@d568d0fbeb11dace1de289d0d95fe1e00b22069d #rxn-onmt-utils without rxn-opennmt-py depedency OpenNMT-py>=3.5.1 # official onmt [options.packages.find] diff --git a/src/rxn/onmt_models/scripts/rxn_onmt_train.py b/src/rxn/onmt_models/scripts/rxn_onmt_train.py index 14352fe..fa33178 100644 --- a/src/rxn/onmt_models/scripts/rxn_onmt_train.py +++ b/src/rxn/onmt_models/scripts/rxn_onmt_train.py @@ -2,13 +2,14 @@ from typing import Tuple import click +from rxn.onmt_utils import __version__ as onmt_utils_version +from rxn.onmt_utils.train_command import OnmtTrainCommand +from rxn.utilities.logging import setup_console_and_file_logger + from rxn.onmt_models import __version__ as onmt_models_version from rxn.onmt_models import defaults from rxn.onmt_models.training_files import ModelFiles, OnmtPreprocessedFiles from rxn.onmt_models.utils import log_file_name_from_time, run_command -from rxn.onmt_utils import __version__ as onmt_utils_version -from rxn.onmt_utils.train_command import OnmtTrainCommand -from rxn.utilities.logging import setup_console_and_file_logger logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -93,7 +94,9 @@ def main( config_file = model_files.next_config_file() - src_vocab, tgt_vocab = get_src_tgt_vocab(data=onmt_preprocessed_files.preprocess_prefix) + src_vocab, tgt_vocab = get_src_tgt_vocab( + data=onmt_preprocessed_files.preprocess_prefix + ) # Init train_cmd = OnmtTrainCommand.train( @@ -120,7 +123,7 @@ def main( # Write config file train_cmd.save_to_config_cmd(config_file) - # Actual training config file + # Actual training config file command_and_args = train_cmd.execute_from_config_cmd(config_file) run_command(command_and_args)