-
Notifications
You must be signed in to change notification settings - Fork 3
/
train_config.yaml
71 lines (70 loc) · 2.04 KB
/
train_config.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
---
# Hydra defaults list: compose this file first, then mount the HRNet
# architecture config at model.params.nn_module.hrnet_config (the path
# read by model.params.nn_module.num_heatmaps).
defaults:
  - _self_
  # NOTE(review): the option name before "@" was destroyed by e-mail
  # obfuscation in the copy this file was recovered from; "hrnet48x2" is
  # inferred from metadata.run_name — confirm against model_config/.
  - model_config/hrnet48x2@model.params.nn_module.hrnet_config
# Experiment bookkeeping for this training run.
metadata:
  experiment_name: HRNet_57
  run_name: hrnet48x2_57_003
  description: Params of the best keypoints model training
  experimenter: Nikolay
# Data loading and augmentation parameters.
data_params:
  # NOTE(review): presumably (width, height); prediction_transform.size
  # lists the same numbers in the opposite order — confirm.
  input_size: [960, 540]
  # Referenced by model.params.loss.num_keypoints.
  num_keypoints: 57
  batch_size: 8
  num_workers: 8
  pin_memory: true
  # Kept equal to model.params.loss.sigma.
  margin: ${model.params.loss.sigma}
  augmentations:
    brightness: [0.8, 1.2]
    color: [0.8, 1.2]
    contrast: [0.8, 1.2]
    gauss_noise_sigma: 30.0
    prob: 0.5
# Dataset directories; each split accepts a list of paths.
data:
  train:
    - /workdir/data/dataset/train
  val:
    - /workdir/data/dataset/valid
# NOTE(review): the nesting below was rebuilt. The paths
# model.params.nn_module.*, model.params.loss.sigma and
# data_params.num_keypoints are fixed by the ${...} interpolations; placing
# optimizer, prediction_transform, amp and pretrain under model.params is
# inferred — confirm against HRNetMetaModel.
model:
  _target_: src.models.hrnet.metamodel.HRNetMetaModel
  params:
    device: cuda:0
    nn_module:
      num_refinement_stages: 0
      # hrnet_config is mounted here by the defaults list.
      num_heatmaps: ${model.params.nn_module.hrnet_config.num_classes}
    loss:
      num_refinement_stages: ${model.params.nn_module.num_refinement_stages}
      stride: 2
      sigma: 3.0
      # Equals [540, 960] / stride, i.e. the reversed input_size divided by
      # the stride — presumably (height, width); confirm.
      pred_size: [270, 480]
      num_keypoints: ${data_params.num_keypoints}
      l2_w: 1.0
      kldiv_w: 0.0
      awing_w: 0.0
    optimizer:
      lr: 0.0001
    prediction_transform:
      size: [540, 960]
    amp: true  # Use AMP for training
    pretrain: null  # Pretrain model path or null
# Training-loop and callback parameters.
train_params:
  load_compatible: true  # Load only compatible weights from pretrain
  max_epochs: 200
  early_stopping_epochs: 32
  reduce_lr_factor: 0.5
  reduce_lr_patience: 8
  # Loss drives the ReduceLROnPlateau and EarlyStopping callbacks, but the
  # best model used for predictions is selected by val_evalai.
  monitor_metric: val_loss
  monitor_metric_better: min
  use_compile: false  # Compile PyTorch model for faster performance
# The camera model is used to estimate the final metrics during training.
# NOTE(review): conf_thresh..reliable_thresh are placed as CameraCreator
# arguments (siblings of pitch); the original nesting was lost — confirm.
camera:
  _target_: src.models.hrnet.prediction.CameraCreator
  pitch:
    _target_: src.datatools.ellipse.get_pitch
  conf_thresh: 0.5  # Min confidence to consider the point detected
  algorithm: "opencv_calibration_multiplane"
  min_points: 5
  min_focal_length: 10.0
  min_points_per_plane: 6
  min_points_for_refinement: 7
  reliable_thresh: 57