forked from open-mmlab/mmdetection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
yolox_s_8x8_300e_coco.py
165 lines (152 loc) · 4.89 KB
/
yolox_s_8x8_300e_coco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
_base_ = [
    '../_base_/schedules/schedule_1x.py',
    '../_base_/default_runtime.py',
]

# Base input resolution, ordered as (height, width).
img_scale = (640, 640)

# Model settings.
model = {
    'type': 'YOLOX',
    'input_size': img_scale,
    'random_size_range': (15, 25),
    'random_size_interval': 10,
    'backbone': {
        'type': 'CSPDarknet',
        'deepen_factor': 0.33,
        'widen_factor': 0.5,
    },
    'neck': {
        'type': 'YOLOXPAFPN',
        'in_channels': [128, 256, 512],
        'out_channels': 128,
        'num_csp_blocks': 1,
    },
    'bbox_head': {
        'type': 'YOLOXHead',
        'num_classes': 80,
        'in_channels': 128,
        'feat_channels': 128,
    },
    'train_cfg': {
        'assigner': {'type': 'SimOTAAssigner', 'center_radius': 2.5},
    },
    # To stay aligned with the reference source code, the score threshold
    # is 0.01 for the val phase and 0.001 for the test phase.
    'test_cfg': {
        'score_thr': 0.01,
        'nms': {'type': 'nms', 'iou_threshold': 0.65},
    },
}
# Dataset settings.
data_root = 'data/coco/'
dataset_type = 'CocoDataset'

# Augmentation and formatting steps applied to training samples.
train_pipeline = [
    {'type': 'Mosaic', 'img_scale': img_scale, 'pad_val': 114.0},
    {
        'type': 'RandomAffine',
        'scaling_ratio_range': (0.1, 2),
        # Negated half of each img_scale entry.
        'border': (-img_scale[0] // 2, -img_scale[1] // 2),
    },
    {
        'type': 'MixUp',
        'img_scale': img_scale,
        'ratio_range': (0.8, 1.6),
        'pad_val': 114.0,
    },
    {'type': 'YOLOXHSVRandomAug'},
    {'type': 'RandomFlip', 'flip_ratio': 0.5},
    # Following the official implementation, multi-scale training is not
    # handled in this pipeline but in 'mmdet/models/detectors/yolox.py'.
    {'type': 'Resize', 'img_scale': img_scale, 'keep_ratio': True},
    {
        'type': 'Pad',
        'pad_to_square': True,
        # For a three-channel image the pad value has to be given
        # separately for each channel.
        'pad_val': {'img': (114.0, 114.0, 114.0)},
    },
    {
        'type': 'FilterAnnotations',
        'min_gt_bbox_wh': (1, 1),
        'keep_empty': False,
    },
    {'type': 'DefaultFormatBundle'},
    {'type': 'Collect', 'keys': ['img', 'gt_bboxes', 'gt_labels']},
]
# Training set: the COCO train2017 split wrapped in MultiImageMixDataset.
# The inner dataset only loads images and boxes; `train_pipeline` is
# attached to the wrapper.
train_dataset = {
    'type': 'MultiImageMixDataset',
    'dataset': {
        'type': dataset_type,
        'ann_file': data_root + 'annotations/instances_train2017.json',
        'img_prefix': data_root + 'train2017/',
        'pipeline': [
            {'type': 'LoadImageFromFile'},
            {'type': 'LoadAnnotations', 'with_bbox': True},
        ],
        'filter_empty_gt': False,
    },
    'pipeline': train_pipeline,
}
# Pipeline shared by the val and test splits: a single scale
# (`img_scale`) with flipping disabled.
test_pipeline = [
    {'type': 'LoadImageFromFile'},
    {
        'type': 'MultiScaleFlipAug',
        'img_scale': img_scale,
        'flip': False,
        'transforms': [
            {'type': 'Resize', 'keep_ratio': True},
            {'type': 'RandomFlip'},
            {
                'type': 'Pad',
                'pad_to_square': True,
                # One pad value per image channel.
                'pad_val': {'img': (114.0, 114.0, 114.0)},
            },
            {'type': 'DefaultFormatBundle'},
            {'type': 'Collect', 'keys': ['img']},
        ],
    },
]
# Data loading settings. Both `val` and `test` point at the val2017
# annotations and images and share `test_pipeline`.
data = {
    'samples_per_gpu': 8,
    'workers_per_gpu': 4,
    'persistent_workers': True,
    'train': train_dataset,
    'val': {
        'type': dataset_type,
        'ann_file': data_root + 'annotations/instances_val2017.json',
        'img_prefix': data_root + 'val2017/',
        'pipeline': test_pipeline,
    },
    'test': {
        'type': dataset_type,
        'ann_file': data_root + 'annotations/instances_val2017.json',
        'img_prefix': data_root + 'val2017/',
        'pipeline': test_pipeline,
    },
}
# Optimizer settings; the defaults target 8 GPUs.
optimizer = {
    'type': 'SGD',
    'lr': 0.01,
    'momentum': 0.9,
    'weight_decay': 5e-4,
    'nesterov': True,
    # Weight-decay multipliers of zero for norm and bias parameters.
    'paramwise_cfg': {'norm_decay_mult': 0., 'bias_decay_mult': 0.},
}
optimizer_config = {'grad_clip': None}
# Schedule constants reused by the settings below.
max_epochs = 300
num_last_epochs = 15
resume_from = None
interval = 10

# Learning-rate policy.
lr_config = {
    '_delete_': True,
    'policy': 'YOLOX',
    'warmup': 'exp',
    'by_epoch': False,
    'warmup_by_epoch': True,
    'warmup_ratio': 1,
    'warmup_iters': 5,  # 5 epochs
    'num_last_epochs': num_last_epochs,
    'min_lr_ratio': 0.05,
}

runner = {'type': 'EpochBasedRunner', 'max_epochs': max_epochs}
# Extra training hooks. The first two are parameterised by
# `num_last_epochs`; the EMA hook receives `resume_from`.
custom_hooks = [
    {
        'type': 'YOLOXModeSwitchHook',
        'num_last_epochs': num_last_epochs,
        'priority': 48,
    },
    {
        'type': 'SyncNormHook',
        'num_last_epochs': num_last_epochs,
        'interval': interval,
        'priority': 48,
    },
    {
        'type': 'ExpMomentumEMAHook',
        'resume_from': resume_from,
        'momentum': 0.0001,
        'priority': 49,
    },
]
# Checkpoints are written every `interval` epochs.
checkpoint_config = {'interval': interval}

evaluation = {
    'save_best': 'auto',
    # While the running epoch is below 'max_epochs - num_last_epochs',
    # evaluation happens every 'interval' epochs. Once the running epoch
    # reaches 'max_epochs - num_last_epochs', the interval becomes 1.
    'interval': interval,
    'dynamic_intervals': [(max_epochs - num_last_epochs, 1)],
    'metric': 'bbox',
}
log_config = {'interval': 50}

# NOTE: `auto_scale_lr` is used for automatic learning-rate scaling;
# USERS SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (8 samples per GPU)
auto_scale_lr = {'base_batch_size': 64}