-
Notifications
You must be signed in to change notification settings - Fork 5
/
run_train.py
105 lines (83 loc) · 4.09 KB
/
run_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from src.utils import hydra_custom_resolvers
from src import utils
import hydra
from omegaconf import OmegaConf, DictConfig
from src.utils import general_helpers
from typing import List
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer
from pytorch_lightning.loggers import LightningLoggerBase
# Module-level logger for this script, obtained from the project's `utils.get_pylogger` helper.
log = utils.get_pylogger(__name__)
def run_train(cfg: DictConfig):
    """Train (and optionally test) a model as described by a Hydra config.

    Instantiates the datamodule, model, callbacks, loggers and trainer described by
    ``cfg`` and fits the model. When ``cfg.test`` is truthy, it then evaluates the best
    checkpoint reported by the trainer's checkpoint callback, falling back to the
    current weights if there is none. The final ``trainer.callback_metrics`` are logged.

    Args:
        cfg: Hydra/OmegaConf config. Keys read here: ``output_dir`` (required), ``seed``,
            ``datamodule``, ``model``, ``callback``, ``logger``, ``trainer``,
            ``resume_from_checkpoint`` and ``test``. ``cfg.output_dir`` is overwritten
            in place with its absolute path.

    Raises:
        ValueError: If ``output_dir`` is not set in ``cfg``.
    """
    # Validate explicitly: an `assert` would be stripped when running under `python -O`.
    if cfg.get("output_dir") is None:
        raise ValueError("Path to the directory in which the predictions will be written must be given")
    cfg.output_dir = general_helpers.get_absolute_path(cfg.output_dir)
    log.info(f"Output directory: {cfg.output_dir}")

    # Set seed for random number generators in PyTorch, Numpy and Python (random).
    # Compare against None so that a seed of 0 is honoured instead of silently skipped.
    seed = cfg.get("seed")
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    log.info(f"Instantiating data module <{cfg.datamodule._target_}>")
    # _recursive_=False: Hydra passes nested configs to the datamodule without instantiating them.
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule, _recursive_=False)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model, datamodule=datamodule)
    datamodule.set_tokenizer(model.tokenizer)
    # If defined, use the model's collate function (otherwise proceed with the PyTorch's default collate_fn)
    if getattr(model, "collator", None):
        datamodule.set_collate_fn(model.collator.collate_fn)

    _check_linearization_consistency(datamodule, model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = general_helpers.instantiate_callbacks(cfg.get("callback"))
    log.info("Instantiating loggers...")
    logger: List[LightningLoggerBase] = general_helpers.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    logging_object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }
    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(logging_object_dict)

    log.info("Starting training!")
    model.output_dir = cfg.output_dir
    resume_ckpt = cfg.get("resume_from_checkpoint")
    if resume_ckpt:
        log.info(f"Resuming from checkpoint: {resume_ckpt}")
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=resume_ckpt)

    # `Trainer.checkpoint_callback` is None when no ModelCheckpoint is configured;
    # treat that the same as "no best checkpoint" instead of crashing on attribute access.
    checkpoint_callback = trainer.checkpoint_callback
    ckpt_path = checkpoint_callback.best_model_path if checkpoint_callback is not None else ""
    log.info(f"Best ckpt path: {ckpt_path}")

    if cfg.get("test"):
        log.info("Starting testing!")
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        model.output_dir = cfg.output_dir
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    # NOTE: for predictions use
    #   trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

    metric_dict = trainer.callback_metrics
    log.info("Metrics dict:")
    log.info(metric_dict)


def _check_linearization_consistency(datamodule, model) -> None:
    """Warn when the training data, model and constraint module disagree on the linearization type.

    This is a precautionary check only: mismatches are logged as warnings and training
    continues.

    Args:
        datamodule: The instantiated datamodule; its ``dataset_parameters["train"]["dataset"]``
            entry provides ``linearization_class_id``.
        model: The instantiated model; provides ``linearization_class.identifier`` and an
            optional (possibly falsy) ``constraint_module``.
    """
    linearization_data = datamodule.dataset_parameters["train"]["dataset"]["linearization_class_id"]
    linearization_model = model.linearization_class.identifier
    if linearization_data != linearization_model:
        log.warning(
            f"The linearization types do not match: "
            f"dataset `{linearization_data}` and model `{linearization_model}`"
        )

    # The constraint module is optional; only compare when the model has one.
    linearization_constraint_module = (
        model.constraint_module.linearization_class_id if model.constraint_module else None
    )
    if linearization_constraint_module and linearization_data != linearization_constraint_module:
        log.warning(
            f"The linearization types do not match: "
            f"dataset `{linearization_data}` and constraint module `{linearization_constraint_module}`"
        )
@hydra.main(version_base="1.2", config_path="configs", config_name="train_root")
def main(hydra_config: DictConfig):
    """Hydra entry point for training.

    Hydra composes the ``train_root`` config from the ``configs`` directory (plus any
    command-line overrides) and passes it in as ``hydra_config``. The config and the
    training task are then handed to ``utils.run_task``.
    """
    task = run_train
    utils.run_task(hydra_config, task)


if __name__ == "__main__":
    # The @hydra.main decorator supplies the config, so main() is invoked without arguments.
    main()