-
Notifications
You must be signed in to change notification settings - Fork 41
/
aifgtm.py
95 lines (78 loc) · 3.76 KB
/
aifgtm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from ..utils import *
from ..attack import Attack
import math
class AIFGTM(Attack):
    """
    AI-FGTM Attack
    'Making Adversarial Examples More Transferable and Indistinguishable (AAAI 2022)' (https://arxiv.org/abs/2007.03838)

    Arguments:
        model_name (str): the name of surrogate model for attack.
        epsilon (float): the perturbation budget.
        alpha (float): the step size. Stored on the instance, but forward() uses the
            per-iteration step size returned by get_alpha() instead.
        beta_1, beta_2 (float): exponential decay rates used by the step-size schedule in get_alpha().
        mu_1, mu_2 (float): the factors applied to the gradient and the squared gradient
            when accumulating the first and second moments.
        lam (float): the scale factor applied to the normalized moment ratio before tanh.
        epoch (int): the number of iterations.
        decay (float): the decay factor for momentum calculation. Stored on the instance;
            not read by any method of this class.
        targeted (bool): targeted/untargeted attack.
        random_start (bool): whether using random initialization for delta.
        norm (str): the norm of perturbation, l2/linfty.
        loss (str): the loss function.
        device (torch.device): the device for data. If it is None, the device would be same as model
        attack (str): the name of the attack, forwarded to the base class.

    Official arguments:
        epsilon=16/255, alpha=epsilon/epoch=1.6/255, epoch=10, decay=1., beta_1=0.9, beta_2=0.99, lam=1.3, mu_1=1.5, mu_2=1.9

    Example script:
        python main.py --input_dir ./path/to/data --output_dir adv_data/aifgtm/resnet18 --attack aifgtm --model=resnet18
        python main.py --input_dir ./path/to/data --output_dir adv_data/aifgtm/resnet18 --eval
    """

    def __init__(self, model_name, epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., targeted=False, random_start=False,
                 norm='linfty', loss='crossentropy', device=None, attack='AI-FGTM', beta_1=0.9, beta_2=0.99, lam=1.3, mu_1=1.5, mu_2=1.9, **kwargs):
        super().__init__(attack, model_name, epsilon, targeted, random_start, norm, loss, device)
        self.alpha = alpha
        self.epoch = epoch
        self.decay = decay
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.lam = lam
        self.mu_1 = mu_1
        self.mu_2 = mu_2

    def _step_weight(self, t):
        """Return (1 - beta_1^(t+1)) / sqrt(1 - beta_2^(t+1)) for the 0-based iteration t."""
        return (1 - self.beta_1 ** (t + 1)) / math.sqrt(1 - self.beta_2 ** (t + 1))

    def get_alpha(self, T, t_):
        """
        Compute the step size for iteration t_ out of T iterations.

        The per-iteration weights are normalized by their sum over all T
        iterations, so the step sizes of a full run add up to self.epsilon.

        Arguments:
            T (int): the total number of iterations.
            t_ (int): the 0-based index of the current iteration.

        Returns:
            float: the step size for iteration t_.

        Note:
            beta_2 values that make 1 - beta_2**(t+1) non-positive cause
            math.sqrt or the division to raise.
        """
        normalizer = sum(self._step_weight(t) for t in range(T))
        return self.epsilon / normalizer * self._step_weight(t_)

    def update_delta(self, delta, data, grad, alpha, **kwargs):
        """
        Apply one update step to the perturbation and project it back into the budget.

        Arguments:
            delta: the current perturbation.
            data: the clean input that delta is added to.
            grad: the update direction (forward() passes the scaled moment ratio).
            alpha (float): the step size for this iteration.

        Returns:
            The updated perturbation, detached from the previous autograd graph
            and marked as requiring grad.
        """
        if self.norm == 'linfty':
            # AI-FGTM replaces the usual sign() with tanh() of the direction.
            delta = torch.clamp(delta + alpha * grad.tanh(), -self.epsilon, self.epsilon)
        else:
            # Normalize each sample's direction to unit L2 norm, step, then cap
            # each sample's flattened perturbation norm at epsilon.
            grad_norm = torch.norm(grad.view(grad.size(0), -1), dim=1).view(-1, 1, 1, 1)
            scaled_grad = grad / (grad_norm + 1e-20)
            delta = (delta + scaled_grad * alpha).view(delta.size(0), -1).renorm(p=2, dim=0, maxnorm=self.epsilon).view_as(delta)
        # clamp / img_min / img_max come from ..utils (not visible here);
        # presumably this keeps data + delta inside the valid image range.
        delta = clamp(delta, img_min-data, img_max-data)
        # Cut the autograd history: without this, every new delta stays chained
        # to all earlier updates and their graph nodes are kept alive for the
        # whole attack. The gradient w.r.t. the current delta is unaffected.
        return delta.detach().requires_grad_()

    def forward(self, data, label, **kwargs):
        """
        The attack procedure for AI-FGTM

        Arguments:
            data: (N, C, H, W) tensor for input images
            label: (N,) tensor for ground-truth labels if untargeted, otherwise a
                pair whose second element is the targeted label tensor

        Returns:
            The final perturbation delta, detached from the autograd graph.
        """
        if self.targeted:
            assert len(label) == 2
            label = label[1]  # the second element is the targeted label tensor
        data = data.clone().detach().to(self.device)
        label = label.clone().detach().to(self.device)

        # init_delta / get_logits / get_loss / get_grad are inherited from the
        # base Attack class (not visible in this file).
        delta = self.init_delta(data)

        # Accumulators for the first moment (gradient) and second moment
        # (squared gradient); they become tensors after the first update.
        momentum = 0
        v = 0
        for t in range(self.epoch):
            logits = self.get_logits(data + delta)
            loss = self.get_loss(logits, label)
            self.model.zero_grad()
            grad = self.get_grad(loss, delta)

            momentum = momentum + self.mu_1 * grad
            v = v + self.mu_2 * grad * grad

            # Iteration-dependent step size instead of the fixed self.alpha.
            alpha = self.get_alpha(self.epoch, t)
            # The small constant avoids division by zero where v is 0.
            direction = self.lam * momentum / (torch.sqrt(v) + 1e-20)
            delta = self.update_delta(delta, data, direction, alpha)
        return delta.detach()