-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_optim.py
123 lines (104 loc) · 5.15 KB
/
train_optim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import os
import time
import numpy as np
from tqdm import tqdm
from tools import save_tensor
from tools.transformer import mixup_transform
from tools.plot import VisualBoard
from tools.loader import dataLoader
from tools.parser import logger
from tools.solver import Cosine_lr_scheduler, Plateau_lr_scheduler, ALRS
# Lookup table from the ``cfg.ATTACKER.scheduler`` string to the scheduler
# class; ``attack`` instantiates the chosen class with the patch optimizer.
scheduler_factor = dict(
    plateau=Plateau_lr_scheduler,
    cosine=Cosine_lr_scheduler,
    ALRS=ALRS,
)
def attack(cfg, data_root, detector_attacker, save_name, args=None):
    """Optimise a universal adversarial patch with Adam over ``data_root``.

    For every batch, the detector predictions are used to locate patch
    positions. Batches with no targets are skipped; otherwise one
    optimisation step is run via ``detector_attacker.attack(..., mode='optim')``.
    The patch is written to ``<args.save_path>/<save_name>.png`` before
    training and every 10 epochs. The per-epoch mean losses are written to
    ``<args.save_path>/<save_name>-loss.npy`` once training finishes.

    Args:
        cfg: Parsed config. Reads ``DETECTOR.INPUT_SIZE``,
            ``DETECTOR.BATCH_SIZE``, ``DATA.AUGMENT``,
            ``ATTACKER.START_LEARNING_RATE``, ``ATTACKER.scheduler`` (a key
            of ``scheduler_factor``) and ``ATTACKER.MAX_EPOCH``.
        data_root: Image directory handed to ``dataLoader``.
        detector_attacker: Attacker wrapper that owns the patch, the
            detectors and the device.
        save_name: File stem for the saved patch and loss array.
        args: Parsed CLI namespace. Although it defaults to ``None``, it is
            required: ``patch``, ``random_erase``, ``save_path``,
            ``debugging``, ``board_name``, ``new_process`` and ``mixup`` are
            all read unconditionally.
    """
    def get_iter():
        # Global step for the visual logger. It reads the enclosing loop
        # variables ``epoch`` and ``index`` at call time (late binding).
        return (epoch - 1) * len(data_loader) + index

    logger(cfg, args)
    data_sampler = None
    detector_attacker.init_universal_patch(args.patch)
    data_loader = dataLoader(data_root,
                             input_size=cfg.DETECTOR.INPUT_SIZE, is_augment='1' in cfg.DATA.AUGMENT,
                             batch_size=cfg.DETECTOR.BATCH_SIZE, sampler=data_sampler, shuffle=True)

    detector_attacker.gates = ['jitter', 'median_pool', 'rotate', 'p9_scale']
    if args.random_erase:
        detector_attacker.gates.append('rerase')

    # The optimizer is given only the patch tensor, so only the patch is updated.
    p_obj = detector_attacker.patch_obj.patch
    optimizer = torch.optim.Adam([p_obj], lr=cfg.ATTACKER.START_LEARNING_RATE, amsgrad=True)
    scheduler = scheduler_factor[cfg.ATTACKER.scheduler](optimizer)
    detector_attacker.attacker.set_optimizer(optimizer)
    loss_array = []

    patch_name = f'{save_name}' + '.png'
    # Snapshot of the initial patch, before any optimisation step.
    save_tensor(detector_attacker.universal_patch, patch_name, args.save_path)

    # Visual logging is disabled in debugging mode.
    vlogger = None
    if not args.debugging:
        vlogger = VisualBoard(optimizer, name=args.board_name, new_process=args.new_process)
        detector_attacker.vlogger = vlogger

    # Running sum of per-epoch mean losses; the ALRS scheduler consumes its
    # 10-epoch average and then resets it.
    ten_epoch_loss = 0
    num_batches = len(data_loader)
    for epoch in range(1, cfg.ATTACKER.MAX_EPOCH + 1):
        et0 = time.time()
        ep_loss = 0
        for index, img_tensor_batch in enumerate(tqdm(data_loader, desc=f'Epoch {epoch}')):
            if vlogger:
                vlogger(epoch, get_iter())
            img_tensor_batch = img_tensor_batch.to(detector_attacker.device)
            if args.mixup:
                img_tensor_batch = mixup_transform(x1=img_tensor_batch)
            all_preds = detector_attacker.detect_bbox(img_tensor_batch)
            # get position of adversarial patches
            target_nums = detector_attacker.get_patch_pos_batch(all_preds)
            if sum(target_nums) == 0:
                # No targets found in this batch: nothing to attack.
                continue
            loss = detector_attacker.attack(img_tensor_batch, mode='optim')
            ep_loss += loss

        # Mean over all batches; skipped batches count as zero loss. The
        # max() guards against an empty loader.
        ep_loss /= max(num_batches, 1)
        # Accumulate before the 10-epoch check so each ALRS window averages
        # exactly the 10 epochs that just finished.
        ten_epoch_loss += ep_loss

        if epoch % 10 == 0:
            save_tensor(detector_attacker.universal_patch, patch_name, args.save_path)
            print('Saving patch to ', os.path.join(args.save_path, patch_name))
            if cfg.ATTACKER.scheduler == 'ALRS':
                scheduler.step(ten_epoch_loss / 10)
                ten_epoch_loss = 0
        et1 = time.time()

        # 'plateau' steps on the epoch loss, ALRS is stepped above every
        # 10 epochs, and any other scheduler steps once per epoch without a metric.
        if cfg.ATTACKER.scheduler == 'plateau':
            scheduler.step(ep_loss)
        elif cfg.ATTACKER.scheduler != 'ALRS':
            scheduler.step()

        if vlogger:
            vlogger.write_ep_loss(ep_loss)
            vlogger.write_scalar(et1 - et0, 'misc/ep time')
        loss_array.append(float(ep_loss))

    np.save(os.path.join(args.save_path, save_name + '-loss.npy'), loss_array)
if __name__ == '__main__':
    from tools.parser import ConfigParser
    from attack.attacker import UniversalAttacker
    import argparse
    import warnings
    warnings.filterwarnings('ignore')

    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--patch', type=str, help='fine-tune from a pre-trained patch', default=None)
    parser.add_argument('-m', '--attack_method', type=str, default='optim')
    parser.add_argument('-cfg', '--cfg', type=str, default='optim.yaml')
    parser.add_argument('-n', '--board_name', type=str, default=None)
    parser.add_argument('-d', '--debugging', action='store_true')
    parser.add_argument('-s', '--save_path', type=str, default='./results/exp2/optim')
    parser.add_argument('-re', '--random_erase', action='store_true', default=False)
    parser.add_argument('-mu', '--mixup', action='store_true', default=False)
    parser.add_argument('-np', '--new_process', action='store_true', default=False)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Default run name is the config filename minus its extension. splitext
    # strips only the final suffix: 'v5.1.yaml' -> 'v5.1', and './x.yaml'
    # -> './x' rather than the empty string that split('.')[0] produced.
    if args.board_name is None:
        save_patch_name = os.path.splitext(args.cfg)[0]
    else:
        save_patch_name = args.board_name
    # --cfg is interpreted relative to the ./configs directory.
    args.cfg = './configs/' + args.cfg

    print('-------------------------Training-------------------------')
    print(' device : ', device)
    print(' cfg :', args.cfg)

    cfg = ConfigParser(args.cfg)
    detector_attacker = UniversalAttacker(cfg, device)
    cfg.show_class_label(cfg.attack_list)
    data_root = cfg.DATA.TRAIN.IMG_DIR
    attack(cfg, data_root, detector_attacker, save_patch_name, args)