-
Notifications
You must be signed in to change notification settings - Fork 7
/
getmnistsample.lua
57 lines (38 loc) · 1.19 KB
/
getmnistsample.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
require 'paths'
require 'torch'
require 'image'
-- Reads the MNIST test split from raw IDX files under `datapath`, converts it
-- to tensors, and writes a grid of 16 randomly chosen digits to
-- samples/mnistsamples.png.
local ffi = require 'ffi'

local datapath = 'mnist'
local testimage = 't10k-images-idx3-ubyte'
local testlabel = 't10k-labels-idx1-ubyte'
-- Training-split file names; not read anywhere in this script.
local trainimage = 'train-images-idx3-ubyte'
local trainlabel = 'train-labels-idx1-ubyte'

--- Read a whole file as a raw byte string.
-- The file is opened in binary mode so that platforms which translate
-- newlines in text mode cannot alter the payload, and the handle is closed
-- before returning.
-- @param path path of the file to read
-- @return the complete file contents as a string
local function readbytes(path)
   -- io.open returns nil plus a message on failure; assert surfaces it.
   local handle = assert(io.open(path, 'rb'))
   local bytes = handle:read('*a')
   handle:close()
   return bytes
end

-- Labels ---------------------------------------------------------------------
local labelpath = paths.concat(datapath, testlabel)
assert(paths.filep(labelpath))
local labeldata = readbytes(labelpath)
print(#labeldata)
-- The trailing 10000 bytes are taken as the labels, one byte per sample
-- (assumes the t10k split holds exactly 10000 samples).
local labels = labeldata:sub(-10000, -1)
print(#labels)
-- Pre-fill with -1 so an element the loop failed to write is detectable.
local targets = torch.LongTensor(#labels):fill(-1)
for i = 1, #labels do
   targets[i] = labels:byte(i)
end
assert(targets:min() ~= -1)
targets:add(1) -- 0-9 -> 1,10

-- Images ---------------------------------------------------------------------
local imagepath = paths.concat(datapath, testimage)
local imagedata = readbytes(imagepath)
print(#imagedata)
-- Skip the first 16 bytes (file header); the rest is treated as pixel bytes.
local images = imagedata:sub(16 + 1, -1)
print(#images, #images / (28 * 28))
local inputs = torch.ByteTensor(#labels, 1, 28, 28)
-- The raw copy below bypasses all bounds checking, so make sure the payload
-- fills the tensor exactly before touching its memory.
assert(#images == inputs:nElement(),
       string.format('image payload is %d bytes but %d were expected',
                     #images, inputs:nElement()))
-- Pass the length explicitly: ffi.copy(dst, str) without a length copies
-- #str + 1 bytes (the string plus a zero terminator), which would write one
-- byte past the end of the tensor storage.
ffi.copy(inputs:data(), images, #images)
inputs = inputs:float()

-- Preview grid ---------------------------------------------------------------
local indices = torch.LongTensor(16):random(1, #labels)
local samples = inputs:index(1, indices)
local display = image.toDisplayTensor(samples, 2, 4)
print(display:size())
-- image.save cannot write into a directory that does not exist yet.
if not paths.dirp('samples') then
   paths.mkdir('samples')
end
image.save("samples/mnistsamples.png", display)