-
Notifications
You must be signed in to change notification settings - Fork 2
/
robustml_mnist.py
50 lines (39 loc) · 1.8 KB
/
robustml_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import argparse
import scipy.io
import tensorflow as tf
import robustml
import models
def load_model(sess, model, model_path):
    """Load weights and biases from a MATLAB ``.mat`` file into ``model``.

    Parameters are assigned positionally: ``model.W[i]`` receives the array
    stored under the i-th weight name and ``model.b[i]`` the array stored under
    the i-th bias name (conv1, conv2, fc1, fc2).

    Args:
        sess: Session passed to every variable's ``load(value, sess)`` call
            (presumably a TF1 ``tf.Session`` -- confirm against caller).
        model: Object exposing ``W`` and ``b`` sequences of variables that
            provide a ``load(value, session)`` method.
        model_path: Path to the ``.mat`` file, read with ``scipy.io.loadmat``.

    Raises:
        ValueError: if ``model.W`` or ``model.b`` does not contain exactly one
            variable per expected parameter name. Without this check ``zip``
            would stop at the shorter sequence and silently leave variables
            unloaded.
        KeyError: if an expected parameter name is missing from the file.
    """
    weight_names = ('weights_conv1', 'weights_conv2', 'weights_fc1', 'weights_fc2')
    bias_names = ('biases_conv1', 'biases_conv2', 'biases_fc1', 'biases_fc2')
    param_file = scipy.io.loadmat(model_path)

    weights = list(model.W)
    biases = list(model.b)
    # Validate everything before the first load so a failure never leaves the
    # model half-initialised.
    for label, variables, names in (('weight', weights, weight_names),
                                    ('bias', biases, bias_names)):
        if len(variables) != len(names):
            raise ValueError(
                'model has %d %s variables but %d %s parameters are expected (%s)'
                % (len(variables), label, len(names), label, ', '.join(names)))
    missing = [name for name in weight_names + bias_names if name not in param_file]
    if missing:
        raise KeyError('parameters missing from %r: %s' % (model_path, ', '.join(missing)))

    for var_tf, var_name_mat in zip(weights, weight_names):
        var_tf.load(param_file[var_name_mat], sess)
    for var_tf, var_name_mat in zip(biases, bias_names):
        # loadmat returns MATLAB vectors as 2-D arrays; flatten to a 1-D bias.
        var_tf.load(param_file[var_name_mat].flatten(), sess)
class Model(robustml.model.Model):
    """robustml ``Model`` wrapping a ``models.LeNetSmall`` MNIST classifier.

    The network weights are read from a ``.mat`` checkpoint with ``load_model``.
    The declared evaluation setting is ``robustml.dataset.MNIST`` under an
    L-infinity threat model with ``epsilon=0.1``.
    """

    # Checkpoint used when ``model_path`` is not given. The path is relative, so
    # it is resolved against the current working directory.
    _DEFAULT_MODEL_PATH = (
        "models/mmr+at/2019-02-17 01:54:16 dataset=mnist nn_type=cnn_lenet_small "
        "p_norm=inf lmbd=0.5 gamma_rb=0.2 gamma_db=0.2 ae_frac=0.5 epoch=100.mat"
    )

    def __init__(self, sess, model_path=None):
        """Build the inference graph and load the checkpoint into ``sess``.

        Args:
            sess: Session used to load the weights and, later, to run
                inference in ``classify`` (presumably a TF1 ``tf.Session``).
            model_path: Optional path to the ``.mat`` checkpoint. Defaults to
                ``_DEFAULT_MODEL_PATH``, so existing ``Model(sess)`` calls
                behave exactly as before.
        """
        if model_path is None:
            model_path = self._DEFAULT_MODEL_PATH
        self._sess = sess
        height, width, n_col = 28, 28, 1
        # A single unbatched image; inputs are assumed to lie in [0, 1].
        self._input = tf.placeholder(tf.float32, (height, width, n_col))
        # Prepend a batch axis of size 1 before feeding the network.
        input_expanded = tf.expand_dims(self._input, axis=0)
        # models.LeNetSmall reads the input geometry from this namespace.
        hps = argparse.Namespace(height=height, width=width, n_col=n_col)
        # NOTE(review): the positional False presumably disables training
        # behaviour -- confirm against models.LeNetSmall.
        self._model = models.LeNetSmall(False, hps)
        # net() returns a sequence; its last element is used as the logits.
        self._logits = self._model.net(input_expanded)[-1]
        load_model(sess, self._model, model_path)
        self._dataset = robustml.dataset.MNIST()
        self._threat_model = robustml.threat_model.Linf(epsilon=0.1)

    @property
    def dataset(self):
        """The robustml dataset this model is evaluated on (MNIST)."""
        return self._dataset

    @property
    def threat_model(self):
        """The robustml threat model (L-infinity, epsilon=0.1)."""
        return self._threat_model

    def classify(self, x):
        """Return the predicted label for a single image ``x``.

        NOTE(review): ``x`` is fed directly to a (28, 28, 1) placeholder;
        confirm this matches the image shape robustml supplies for MNIST.
        """
        # [0] drops the size-1 batch axis added in __init__.
        logits_val = self._sess.run(self._logits, feed_dict={self._input: x})[0]
        # argmax yields a NumPy integer; return a plain Python int label.
        return int(logits_val.argmax())