soft-teacher_faster-rcnn_r50-caffe_fpn_180k_semi-0.1-coco.py
_base_ = [
'../_base_/models/faster-rcnn_r50_fpn.py', '../_base_/default_runtime.py',
'../_base_/datasets/semi_coco_detection.py'
]
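# Reuse the Faster R-CNN detector from the base config, but switch to the
# caffe-style ResNet-50: BGR input, per-channel mean subtraction only
# (std = 1.0), and Detectron2 pretrained weights.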
detector = _base_.model
detector.data_preprocessor = dict(
type='DetDataPreprocessor',
mean=[103.530, 116.280, 123.675],
std=[1.0, 1.0, 1.0],
bgr_to_rgb=False,
pad_size_divisor=32)
detector.backbone = dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe',
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://detectron2/resnet50_caffe'))
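# Wrap the detector in the SoftTeacher semi-supervised framework.
# `_delete_=True` drops the base `model` definition so this dict fully
# replaces it. The MultiBranchDataPreprocessor applies the detector's
# preprocessing to each branch (labeled, and the teacher/student views of
# unlabeled images). The semi_train_cfg thresholds filter the teacher's
# pseudo boxes; jitter_times/jitter_scale control the box jittering used,
# together with reg_pseudo_thr, to decide which pseudo boxes supervise
# box regression.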
model = dict(
_delete_=True,
type='SoftTeacher',
detector=detector,
data_preprocessor=dict(
type='MultiBranchDataPreprocessor',
data_preprocessor=detector.data_preprocessor),
semi_train_cfg=dict(
freeze_teacher=True,
sup_weight=1.0,
unsup_weight=4.0,
pseudo_label_initial_score_thr=0.5,
rpn_pseudo_thr=0.9,
cls_pseudo_thr=0.9,
reg_pseudo_thr=0.02,
jitter_times=10,
jitter_scale=0.06,
min_pseudo_bbox_wh=(1e-2, 1e-2)),
semi_test_cfg=dict(predict_on='teacher'))
# 10% coco train2017 is set as labeled dataset
labeled_dataset = _base_.labeled_dataset
unlabeled_dataset = _base_.unlabeled_dataset
labeled_dataset.ann_file = 'semi_anns/[email protected]'
unlabeled_dataset.ann_file = 'semi_anns/' \
                             '[email protected]'
unlabeled_dataset.data_prefix = dict(img='train2017/')
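# The base config concatenates the two datasets and samples labeled and
# unlabeled images at a fixed ratio within each batch.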
train_dataloader = dict(
dataset=dict(datasets=[labeled_dataset, unlabeled_dataset]))
# training schedule for 180k
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=180000, val_interval=5000)
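# TeacherStudentValLoop evaluates both the teacher and the student models.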
val_cfg = dict(type='TeacherStudentValLoop')
test_cfg = dict(type='TestLoop')
# learning rate policy
param_scheduler = [
dict(
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
dict(
type='MultiStepLR',
begin=0,
end=180000,
by_epoch=False,
milestones=[120000, 160000],
gamma=0.1)
]
# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
default_hooks = dict(
checkpoint=dict(by_epoch=False, interval=10000, max_keep_ckpts=2))
log_processor = dict(by_epoch=False)
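# MeanTeacherHook updates the teacher weights as an exponential moving
# average of the student after every training iteration.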
custom_hooks = [dict(type='MeanTeacherHook')]