# mnist_vanilla_nn.py
# A vanilla two-hidden-layer feed-forward network for MNIST classification,
# written against the TensorFlow v1 graph API.
import numpy as np
import tensorflow as tf
from mnist_utils import load_mnist, IMAGE_SIZE, N_CLASSES
# Network architecture and training hyperparameters.
N_HIDDEN1 = 256   # units in the first hidden layer
N_HIDDEN2 = 256   # units in the second hidden layer
BATCH_SIZE = 200  # examples per gradient step
EPOCHS = 5        # full passes over the training set
# Graph inputs: flattened images and one-hot labels.
# The leading dimension is None so the batch size stays dynamic
# (mini-batches during training, the full split during evaluation).
x = tf.placeholder(tf.float32, [None, IMAGE_SIZE], name='x')
y = tf.placeholder(tf.float32, [None, N_CLASSES], name='y')
# Model: two ReLU hidden layers followed by a softmax output layer.
# Weights start from a small truncated normal (stddev 0.01); biases at zero.
with tf.name_scope('hidden1'):
    W = tf.Variable(tf.truncated_normal([IMAGE_SIZE, N_HIDDEN1], stddev=0.01),
                    name='weights')
    # BUG FIX: 'biases' was previously passed positionally, which lands in
    # tf.Variable's second parameter `trainable` (not `name`) — leaving the
    # variable unnamed and setting trainable to a string. Pass name= like
    # the other two layers do.
    b = tf.Variable(tf.zeros([N_HIDDEN1]), name='biases')
    hidden1 = tf.nn.relu(tf.matmul(x, W) + b)
with tf.name_scope('hidden2'):
    W = tf.Variable(tf.truncated_normal([N_HIDDEN1, N_HIDDEN2], stddev=0.01),
                    name='weights')
    b = tf.Variable(tf.zeros([N_HIDDEN2]), name='biases')
    hidden2 = tf.nn.relu(tf.matmul(hidden1, W) + b)
with tf.name_scope('softmax'):
    W = tf.Variable(tf.truncated_normal([N_HIDDEN2, N_CLASSES], stddev=0.01),
                    name='weights')
    b = tf.Variable(tf.zeros([N_CLASSES]), name='biases')
    # y_hat holds per-class probabilities for each example in the batch.
    y_hat = tf.nn.softmax(tf.matmul(hidden2, W) + b)
# Loss: mean cross-entropy between one-hot labels and predicted probabilities.
# y_hat is clipped away from 0 so tf.log cannot produce -inf/NaN when the
# softmax saturates — the unclipped original could derail training with NaNs.
cross_entropy = tf.reduce_mean(
    - tf.reduce_sum(y * tf.log(tf.clip_by_value(y_hat, 1e-10, 1.0)),
                    reduction_indices=[1]))
# Training step: Adam with a fixed 1e-3 learning rate.
training_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
# Evaluation: fraction of examples whose arg-max class matches the label.
predicted_class = tf.argmax(y_hat, 1)
true_class = tf.argmax(y, 1)
correct_prediction = tf.equal(predicted_class, true_class)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32))
# Op that initializes all model variables; run once per session.
# NOTE(review): tf.initialize_all_variables() is deprecated in favor of
# tf.global_variables_initializer() on TF >= 0.12 — confirm the TF version
# in use before switching.
init = tf.initialize_all_variables()
# Load MNIST once; keep separate references to the training arrays so they
# can be reshuffled between epochs without touching the dataset object.
mnist = load_mnist()
train_data = mnist.train.data
train_target = mnist.train.target
# BUG FIX: floor division. Plain `/` yields a float under Python 3 and makes
# range(n_batches) raise TypeError; `//` behaves identically under Python 2.
# Any remainder examples beyond the last full batch are dropped.
n_batches = mnist.train.data.shape[0] // BATCH_SIZE
# Train for EPOCHS passes over the data, reporting validation accuracy every
# 20 batches, then evaluate once on the test split.
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(EPOCHS):
        # Single-string print() works identically under Python 2 and 3;
        # the original bare `print a, b` statements were Python-2-only.
        print('Epoch {0}'.format(epoch + 1))
        for batch in range(n_batches):
            # Mini-batch slice of the (shuffled) training arrays.
            start_i = batch * BATCH_SIZE
            end_i = start_i + BATCH_SIZE
            feed_dict = {x: train_data[start_i:end_i],
                         y: train_target[start_i:end_i]}
            sess.run(training_step, feed_dict=feed_dict)
            # Periodic progress report on the held-out validation split.
            if batch % 20 == 0:
                feed_dict = {x: mnist.validation.data,
                             y: mnist.validation.target}
                val_acc = sess.run(accuracy, feed_dict=feed_dict)
                print('Batch {0}: validation accuracy {1}'.format(batch, val_acc))
        # Reshuffle the training set between epochs. np.random.permutation
        # replaces the original `perm = range(n); np.random.shuffle(perm)`,
        # which breaks under Python 3 (range objects are immutable).
        perm = np.random.permutation(mnist.train.data.shape[0])
        train_data = mnist.train.data[perm]
        train_target = mnist.train.target[perm]
    # Final evaluation on the test split.
    feed_dict = {x: mnist.test.data,
                 y: mnist.test.target}
    print('Test-set accuracy {0}'.format(sess.run(accuracy, feed_dict=feed_dict)))