forked from cmusatyalab/openface
-
Notifications
You must be signed in to change notification settings - Fork 0
/
batch-represent.lua
59 lines (50 loc) · 1.5 KB
/
batch-represent.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
local ffi = require 'ffi'
-- File-level counters shared by batchRepresent and repBatch.
-- batchNumber: running count of images embedded so far (starts at 0).
-- nImgs: total number of images to embed; nil until batchRepresent sets it.
local batchNumber = 0
local nImgs
--- Embed every image found under `opt.data` and stream the results out.
-- Builds a `dataLoader` over `opt.data` with `split = 0` and reads its size
-- via `sizeTest()` (presumably every image is treated as test data — confirm
-- against dataLoader). It then walks the images in chunks of `opt.batchSize`
-- and hands each chunk, plus the image paths it came from, to `repBatch`.
-- Reads the globals `opt` (`imgDim`, `data`, `batchSize`, `cuda`),
-- `dataLoader` and `repBatch`; sets the file-level counters `nImgs` and
-- `batchNumber`.
function batchRepresent()
   -- A zero batch size would make the batch count infinite, and a negative
   -- one would silently produce no output; reject both up front.
   assert(opt.batchSize and opt.batchSize > 0,
          "opt.batchSize must be a positive number")

   local loadSize = {3, opt.imgDim, opt.imgDim}
   local dumpLoader = dataLoader{
      paths = {opt.data},
      loadSize = loadSize,
      sampleSize = loadSize,
      split = 0,
      verbose = true
   }

   nImgs = dumpLoader:sizeTest()
   print('nImgs: ', nImgs)
   assert(nImgs > 0, "Failed to get nImgs")

   batchNumber = 0
   local nBatches = math.ceil(nImgs / opt.batchSize)
   for batchIdx = 1, nBatches do
      local indexStart = (batchIdx - 1) * opt.batchSize + 1
      -- The final batch is short when nImgs is not a multiple of batchSize.
      local indexEnd = math.min(nImgs, indexStart + opt.batchSize - 1)
      local inputs, labels = dumpLoader:get(indexStart, indexEnd)

      -- Each imagePath entry is exposed through :data() and copied out with
      -- ffi.string, i.e. it is read as a NUL-terminated char buffer
      -- (presumably a torch.CharTensor row — confirm against dataLoader).
      local paths = {}
      for imgIdx = indexStart, indexEnd do
         paths[#paths + 1] = ffi.string(dumpLoader.imagePath[imgIdx]:data())
      end

      repBatch(paths, inputs, labels)

      -- Periodically reclaim temporaries left over from earlier batches.
      if batchIdx % 5 == 0 then
         collectgarbage()
      end
   end

   if opt.cuda then
      cutorch.synchronize()
   end
end
--- Embed one batch of images and append the results to the output CSVs.
-- @param paths  array of image path strings, one per image in the batch
-- @param inputs batch of images passed to `model:forward`
-- @param labels labels for the batch; `labels:size(1)` is the batch length
-- For each image, writes `{label, path}` to `labelsCSV` and the embedding
-- row to `repsCSV`, then prints progress. Reads the globals `opt`, `model`,
-- `labelsCSV`, `repsCSV` (and `cutorch` when `opt.cuda` is set); advances
-- the file-level `batchNumber`.
function repBatch(paths, inputs, labels)
   -- labels:size(1) equals the batch size except possibly for the last
   -- batch, which holds whatever images remain.
   local n = labels:size(1)
   batchNumber = batchNumber + n

   if opt.cuda then
      inputs = inputs:cuda()
   end
   local embeddings = model:forward(inputs):float()
   if opt.cuda then
      cutorch.synchronize()
   end

   -- With a single-image batch the network output may come back 1-D (just
   -- the embedding vector). Indexing a 1-D tensor with [1] yields a number,
   -- not a row, so :totable() below would fail; restore the batch dimension.
   if n == 1 and embeddings:dim() == 1 then
      embeddings = embeddings:reshape(1, embeddings:size(1))
   end

   for i = 1, n do
      labelsCSV:write({labels[i], paths[i]})
      repsCSV:write(embeddings[i]:totable())
   end
   print(('Represent: %d/%d'):format(batchNumber, nImgs))
end