Skip to content

Commit

Permalink
fix: added model_task argument for correct config in retro task + upd…
Browse files Browse the repository at this point in the history
…ate deps utils
  • Loading branch information
irinaespejo committed May 24, 2024
1 parent 98b79a0 commit 12066d2
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ install_requires =
rxn-chem-utils>=1.1.4
rxn-reaction-preprocessing>=2.0.2
rxn-utils>=1.1.9
rxn-onmt-utils @ git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@f1a0b970411aac308a3cba36c942297933a4dd91 #rxn-onmt-utils without rxn-opennmt-py depedency
rxn-onmt-utils @ git+https://github.com/rxn4chemistry/rxn-onmt-utils.git@6725e6dad0c3fc563aa3eba498289c207fe9b8e2 #rxn-onmt-utils without rxn-opennmt-py depedency
OpenNMT-py>=3.5.1 # official onmt

[options.packages.find]
Expand Down
3 changes: 3 additions & 0 deletions src/rxn/onmt_models/scripts/rxn_onmt_continue_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
default=100000,
help="Number of steps, including steps from the initial training run.",
)
@click.option("--model_task", type=str, required=True)
def main(
batch_size: int,
data_weights: Tuple[int, ...],
Expand All @@ -66,6 +67,7 @@ def main(
preprocess_dir: str,
train_from: Optional[str],
train_num_steps: int,
model_task: str,
) -> None:
"""Continue training for an OpenNMT model.
Expand Down Expand Up @@ -111,6 +113,7 @@ def main(
train_steps=train_num_steps,
no_gpu=no_gpu,
data_weights=data_weights,
model_task=model_task,
)

# Write config file
Expand Down
3 changes: 3 additions & 0 deletions src/rxn/onmt_models/scripts/rxn_onmt_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
@click.option("--warmup_steps", default=defaults.WARMUP_STEPS)
@click.option("--report_every", default=1000)
@click.option("--save_checkpoint_steps", default=5000)
@click.option("--model_task", type=str, required=True)
def main(
batch_size: int,
data_weights: Tuple[int, ...],
Expand All @@ -69,6 +70,7 @@ def main(
warmup_steps: int,
report_every: int,
save_checkpoint_steps: int,
model_task: str,
) -> None:
"""Finetune an OpenNMT model."""

Expand Down Expand Up @@ -122,6 +124,7 @@ def main(
data_weights=data_weights,
report_every=report_every,
save_checkpoint_steps=save_checkpoint_steps,
model_task=model_task,
)

# Write config file
Expand Down
3 changes: 3 additions & 0 deletions src/rxn/onmt_models/scripts/rxn_onmt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def check_rnn_vs_hidden_size(hidden_size: int, rnn_size: int) -> int:
@click.option("--transformer_ff", default=defaults.TRANSFORMER_FF)
@click.option("--warmup_steps", default=defaults.WARMUP_STEPS)
@click.option("--word_vec_size", default=defaults.WORD_VEC_SIZE)
@click.option("--model_task", type=str, required=True)
def main(
batch_size: int,
data_weights: Tuple[int, ...],
Expand All @@ -94,6 +95,7 @@ def main(
transformer_ff: int,
warmup_steps: int,
word_vec_size: int,
model_task: str,
) -> None:
"""Train an OpenNMT model.
Expand Down Expand Up @@ -143,6 +145,7 @@ def main(
word_vec_size=word_vec_size,
no_gpu=no_gpu,
data_weights=data_weights,
model_task=model_task,
)

# Write config file
Expand Down

0 comments on commit 12066d2

Please sign in to comment.