-
Notifications
You must be signed in to change notification settings - Fork 3
/
train_config.yaml
52 lines (52 loc) · 1.3 KB
/
train_config.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
---
# Hydra defaults list. Entries are composed in order: this file (_self_)
# first, then the chosen model_config option, which is placed at the package
# written after "@" (model.params.nn_module.hrnet_config).
defaults:
  - _self_
  # NOTE(review): the option name before "@" was destroyed by e-mail
  # obfuscation in the reviewed copy ("[email protected]"). The package path is
  # recoverable from the surviving tail and from the num_heatmaps
  # interpolation; "hrnetv2_w48" is a presumed option name - confirm it
  # against the files available in model_config/.
  - model_config/hrnetv2_w48@model.params.nn_module.hrnet_config
# Identification fields for this experiment run.
# Children are nested under `metadata`; in the flattened copy they parsed as
# top-level keys and `metadata` itself as null.
metadata:
  experiment_name: line
  run_name: extreme_heatpoints2
  description: detect two end points for each line
  experimenter: RUILONG
# Data loading and target-generation settings.
# `stride` must live directly under `data_params`: the model section reads it
# through the ${data_params.stride} interpolation.
# NOTE(review): indentation was lost in the reviewed copy. Placing
# `augmentations` (and `prob` as its last child) inside `data_params` is a
# presumed reconstruction - confirm against the consuming code.
data_params:
  input_size: [960, 540]  # [960, 544]
  stride: 4
  sigma: 1
  num_keypoint_pairs: 23
  batch_size: 8
  num_workers: 8
  pin_memory: true
  augmentations:
    brightness: [0.8, 1.2]
    color: [0.8, 1.2]
    contrast: [0.8, 1.2]
    gauss_noise_sigma: 30.0
    prob: 0.5
# Dataset locations for the training and validation splits.
# NOTE(review): presumed to be a top-level section (sibling of data_params);
# indentation was lost in the reviewed copy - confirm.
data:
  train: /workdir/data/dataset/train
  val: /workdir/data/dataset/valid
# Model definition. `_target_` names the class to build; `params` holds its
# settings.
# The nesting of `params.nn_module` is fixed by the interpolation paths used
# below (model.params.nn_module.*). Each `sigma` belongs to its own section;
# flattened, they collide as duplicate keys.
# NOTE(review): placing `amp` and `pretrain` directly under `params` (rather
# than under `prediction_transform` or `model`) is a presumed reconstruction -
# confirm against EHMMetaModel.
model:
  _target_: src.models.line.metamodel.EHMMetaModel
  params:
    device: cuda:0
    nn_module:
      num_refinement_stages: 0
      # hrnet_config is supplied by the model_config entry of the defaults
      # list; the heatmap count follows its num_classes.
      num_heatmaps: ${model.params.nn_module.hrnet_config.num_classes}
    loss:
      # Kept in sync with the network's number of refinement stages.
      num_refinement_stages: ${model.params.nn_module.num_refinement_stages}
      gmse_w: 1.0
      awing_w: 1.0
      sigma: 4
    optimizer:
      lr: 0.001
    prediction_transform:
      # Reuses the data pipeline's stride value.
      scale: ${data_params.stride}
      sigma: 6
    amp: true
    # null or path
    pretrain: /workdir/data/experiments/line_extreme_heatpoints2/model-035-0.834300.pth
# Training-loop settings: epoch budget, early stopping, LR reduction and the
# metric that is monitored.
# NOTE(review): `monitor_metric_better` presumably gives the direction in
# which `monitor_metric` improves (max for an accuracy-like metric) - confirm
# the accepted values in the training script.
train_params:
  max_epochs: 100
  early_stopping_epochs: 50
  reduce_lr_factor: 0.8
  reduce_lr_patience: 8
  monitor_metric: val_acc
  monitor_metric_better: max