forked from leosampaio/sketchformer
-
Notifications
You must be signed in to change notification settings - Fork 2
/
run-experiment.py
79 lines (62 loc) · 3.13 KB
/
run-experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import pprint
import utils
import models
import dataloaders
import experiments
def main(argv=None):
    """Parse command-line options, optionally build a model, and run an experiment.

    The experiment class is looked up with
    ``experiments.get_experiment_by_name``. If ``--help-hps`` is given, the
    experiment's default hparams are pretty-printed and the function returns
    without touching the GPU or building anything.

    Otherwise the selected GPU(s) are configured, and when the experiment
    class has ``requires_model`` set, a model and its data loader are built:
    their default hparams are combined, the config for ``--model-id`` under
    ``--output-dir`` is loaded, ``--model-hparams`` overrides are applied, and
    ``restore_checkpoint_if_exists`` is called with ``--resume``. Finally the
    experiment is constructed and ``compute`` is called with the model (or
    ``None`` when the experiment does not require one).

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behavior.
    """
    parser = argparse.ArgumentParser(
        description='Train modified transformer with sketch data')
    # A positional argument is always required; a default would be ignored.
    parser.add_argument("experiment_name",
                        help="Reference name of experiment that you want to run")
    parser.add_argument("--id", default="0", help="Experiment signature")
    parser.add_argument("-o", "--output-dir", default="", help="output directory")
    parser.add_argument("--exp-hparams", default=None,
                        help="Parameters to override defaults for experiment")
    parser.add_argument("--model-hparams", default=None,
                        help="Parameters to override defaults for model")
    # nargs='+' always yields a list when -g is passed, while argparse uses
    # the default verbatim when it is not. Keep the default a list too, so
    # setup_gpu receives the same type either way.
    parser.add_argument("-g", "--gpu", default=[0], type=int, nargs='+',
                        help="GPU ID(s) to run on")
    parser.add_argument("--model-name", default=None,
                        help="Model that you want to experiment on")
    parser.add_argument("--model-id", default=None,
                        help="ID of the model to load; used to locate its "
                             "saved config")
    parser.add_argument("--data-loader", default='stroke3-distributed',
                        help="Data loader that will provide data for model, "
                             "if you want to load a model")
    parser.add_argument("--dataset", default=None,
                        help="Input data folder if you want to load a model")
    parser.add_argument("-r", "--resume", default='latest',
                        help="One of 'latest' or a checkpoint name")
    parser.add_argument("--help-hps", action="store_true",
                        help="Prints out the hparams default values")
    args = parser.parse_args(argv)

    Experiment = experiments.get_experiment_by_name(args.experiment_name)

    # check for lost users in the well of despair: only show the defaults
    if args.help_hps:
        hps_description = pprint.pformat(Experiment.default_hparams().values())
        print("\nDefault params for experiment {}: \n{}\n\n".format(
            args.experiment_name, hps_description))
        return

    utils.gpu.setup_gpu(args.gpu)

    # load model if that is what the experiment requires
    model = None
    if Experiment.requires_model:
        Model = models.get_model_by_name(args.model_name)
        DataLoader = dataloaders.get_dataloader_by_name(args.data_loader)

        # Model config = model defaults combined with data-loader defaults,
        # then handed to load_config with this model id's config file path
        # (presumably overriding the defaults; confirm in utils.hparams).
        model_hps = utils.hparams.combine_hparams_into_one(
            Model.default_hparams(), DataLoader.default_hparams())
        utils.hparams.load_config(
            model_hps, Model.get_config_filepath(args.output_dir, args.model_id))

        # optional command-line override of parameters, applied last
        if args.model_hparams:
            model_hps.parse(args.model_hparams)

        dataset = DataLoader(model_hps, args.dataset)
        model = Model(model_hps, dataset, args.output_dir, args.model_id)
        model.restore_checkpoint_if_exists(args.resume)

    experiment_hps = Experiment.parse_hparams(args.exp_hparams)

    # finally, run the experiment (model is None if it was not required)
    experiment = Experiment(experiment_hps, args.id, args.output_dir)
    experiment.compute(model)
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()