Skip to content

Commit

Permalink
fix: upgraded onmt_preprocess to onmt_build_vocab, added wrapper argu…
Browse files Browse the repository at this point in the history
…ments
  • Loading branch information
irinaespejo committed Apr 9, 2024
1 parent 40da86d commit c0b47cc
Showing 1 changed file with 107 additions and 17 deletions.
124 changes: 107 additions & 17 deletions src/rxn/onmt_models/scripts/rxn_onmt_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from typing import List, Optional, Tuple

import click
import yaml
from rxn.chemutils.tokenization import ensure_tokenized_file
from rxn.onmt_models import __version__ as onmt_models_version
from rxn.onmt_models import defaults
from rxn.onmt_models.training_files import OnmtPreprocessedFiles, RxnPreprocessingFiles
from rxn.onmt_models.utils import log_file_name_from_time, run_command
from rxn.onmt_utils import __version__ as onmt_utils_version
from rxn.onmt_utils.train_command import preprocessed_id_names
from rxn.utilities.files import (
Expand All @@ -15,11 +20,6 @@
)
from rxn.utilities.logging import setup_console_and_file_logger

from rxn.onmt_models import __version__ as onmt_models_version
from rxn.onmt_models import defaults
from rxn.onmt_models.training_files import OnmtPreprocessedFiles, RxnPreprocessingFiles
from rxn.onmt_models.utils import log_file_name_from_time, run_command

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

Expand Down Expand Up @@ -51,6 +51,89 @@ def determine_train_dataset(
return src, tgt


def get_build_vocab_config_file(
train_srcs: List[PathLike],
train_tgts: List[PathLike],
valid_src: List[PathLike],
valid_tgt: List[PathLike],
save_data: Path,
share_vocab: bool = True,
overwrite: bool = True,
src_seq_length: int = 3000,
tgt_seq_length: int = 3000,
src_vocab_size: int = 3000,
tgt_vocab_size: int = 3000,
) -> Path:
"""Wrapper function of the legacy cli `onmt_preprocessed` arguments.
The goal is to make them compatible with ONMT v.3.5.1 cli `onmt_build_vocab`.
The function takes the arguments of former onmt_preprocessed cli and dumps
them into a `config.yaml` file with a specific structure compatible with `onmt_build_vocab`.
The upgraded `onmt_build_vocab` takes them as `onmt_build_vocab -config config.yaml`.
Args:
train_srcs (List[PathLike]): List of train source data files
train_tgts (List[PathLike]): List of train target data files
valid_src (List[PathLike]): List of validation source data files
valid_tgt (List[PathLike]): List of validation target data files
save_data (PathLike): Save vocabulary data directory
share_vocab (bool, optional): Share vocab. Defaults to True.
overwrite (bool, optional): Overwrite output directory. Defaults to True.
src_seq_length (int, optional): src_seq_length. Defaults to 3000.
tgt_seq_length (int, optional): tgt_seq_length. Defaults to 3000.
src_vocab_size (int, optional): src_vocab_size. Defaults to 3000.
tgt_vocab_size (int, optional): tgt_vocab_size. Defaults to 3000.
Returns:
PathLike: Path of the config.yaml which is in directory `save_data`
"""

# Build dictionary with build vocab config content
# See structure https://opennmt.net/OpenNMT-py/quickstart.html (Step 1: Prepare the data)
build_vocab_config = dict()

# Arguments save data
build_vocab_config["save_data"] = str(save_data.parent)
build_vocab_config["src_vocab"] = str(
save_data.parent / (save_data.name + ".vocab.src")
)
build_vocab_config["tgt_vocab"] = str(
save_data.parent / (save_data.name + ".vocab.tgt")
)

# Other arguments
build_vocab_config["overwrite"] = overwrite
build_vocab_config["share_vocab"] = share_vocab
build_vocab_config["src_seq_length"] = src_seq_length
build_vocab_config["tgt_seq_length"] = tgt_seq_length
build_vocab_config["src_vocab_size"] = src_vocab_size
build_vocab_config["tgt_vocab_size"] = tgt_vocab_size

# Arguments data paths (train)
build_vocab_config["data"] = dict()
# TODO: raise error if lengths: train_srcs, train_tgts, valid_src, valid_tgt are different
number_corpus = len(train_srcs)
for i in range(number_corpus):
build_vocab_config["data"][f"corpus_{i+1}"] = {
"path_src": str(train_srcs[i]),
"path_tgt": str(train_tgts[i]),
}

# Arguments data paths (valid)
build_vocab_config["data"]["valid"] = {
"path_src": str(valid_src),
"path_tgt": str(valid_tgt),
}

# Path to same yaml file
config_file_path = save_data.parent / (save_data.name + "_build_vocab_config.yaml")

# Save file that will be -config argument of onmt_build_vocab
with open(config_file_path, "w+") as file:
yaml.dump(build_vocab_config, file)

return config_file_path


@click.command()
@click.option(
"--input_dir",
Expand Down Expand Up @@ -180,21 +263,28 @@ def main(
valid_src = ensure_tokenized_file(valid_src)
valid_tgt = ensure_tokenized_file(valid_tgt)

# Create config file for onmt_build_vocab for OpenNMT v.3.5.1
# Dump train_srcs, train_tgts, valid_src, valid_tgt etc and return path
config_file_path = get_build_vocab_config_file(
train_srcs=train_srcs,
train_tgts=train_tgts,
valid_src=valid_src,
valid_tgt=valid_tgt,
save_data=onmt_preprocessed_files.preprocess_prefix,
share_vocab=True,
overwrite=True,
src_seq_length=3000,
tgt_seq_length=3000,
src_vocab_size=3000,
tgt_vocab_size=3000,
)

# yapf: disable
command_and_args = [
str(e) for e in [
'onmt_preprocess',
'-train_src', *train_srcs,
'-train_tgt', *train_tgts,
'-valid_src', valid_src,
'-valid_tgt', valid_tgt,
'-save_data', onmt_preprocessed_files.preprocess_prefix,
'-src_seq_length', 3000,
'-tgt_seq_length', 3000,
'-src_vocab_size', 3000,
'-tgt_vocab_size', 3000,
'-share_vocab',
'-overwrite',
'onmt_build_vocab',
'-config', config_file_path,
'-n_sample', 3000,
]
]
# yapf: enable
Expand Down

0 comments on commit c0b47cc

Please sign in to comment.