-
Notifications
You must be signed in to change notification settings - Fork 40
/
model.lua
50 lines (37 loc) · 1.25 KB
/
model.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
require 'nn'
require 'inn'
require 'cudnn'
-- Load the helper module once and pull out the two conversion utilities.
-- (Calling paths.dofile per helper would execute utils.lua twice.)
local utils = paths.dofile('utils.lua')
local reshapeLastLinearLayer = utils.reshapeLastLinearLayer
local convertCaffeModelToTorch = utils.convertCaffeModelToTorch

-- 1.1. Create Network
-- opt.netType names a file under models/ whose chunk returns a constructor
-- that is called with opt.backend.
local config = opt.netType
local modelFile = 'models/' .. config .. '.lua'
local createModel = paths.dofile(modelFile)
print('=> Creating model from file: ' .. modelFile)
model = createModel(opt.backend)
-- convert to accept inputs in the range 0-1 RGB format
convertCaffeModelToTorch(model,{1,1})
-- Resize the final linear layer to #classes + 1 outputs
-- (the extra output is presumably the background class -- TODO confirm).
reshapeLastLinearLayer(model,#classes+1)
image_mean = {128/255,128/255,128/255}

-- Expose the parts of the network the detection framework works with.
if opt.algo == 'RCNN' then
  -- RCNN: the whole network acts as the classifier.
  classifier = model
elseif opt.algo == 'SPP' then
  -- SPP: module 1 is the feature extractor, module 3 the classifier head.
  features = model:get(1)
  classifier = model:get(3)
else
  -- Fail fast instead of continuing with an unset/stale global `classifier`.
  error("Unsupported opt.algo: " .. tostring(opt.algo) ..
        " (expected 'RCNN' or 'SPP')")
end
-- 2. Create Criterion
-- nn.CrossEntropyCriterion combines LogSoftMax and ClassNLLCriterion, so it
-- is applied directly to the network's output scores.
criterion = nn.CrossEntropyCriterion()

-- Log the network and then the criterion, each preceded by a header line.
for _, entry in ipairs({ { '=> Model', model }, { '=> Criterion', criterion } }) do
  print(entry[1])
  print(entry[2])
end
-- 3. Optional warm start: when opt.retrain is not 'none', it must be the path
-- of a serialized classifier, which replaces the freshly built one.
-- NOTE(review): only the global `classifier` is replaced; `model` (and, for
-- SPP, the module returned by model:get(3)) still holds the freshly built
-- network -- confirm that downstream code reads `classifier`.
do
  local retrainFile = opt.retrain
  if retrainFile ~= 'none' then
    assert(paths.filep(retrainFile), 'File not found: ' .. retrainFile)
    print('Loading model from file: ' .. retrainFile)
    classifier = torch.load(retrainFile)
  end
end
-- 4. Convert model to CUDA
-- nn's Module:cuda() converts in place and returns the module itself, so
-- `features` (a sub-module of `model`) follows automatically.
print('==> Converting model to CUDA')
model = model:cuda()
-- A classifier loaded via opt.retrain is a standalone object that is not part
-- of `model`, so the call above would not move it to the GPU. When
-- `classifier` is a sub-module of `model` (the non-retrain case), converting
-- it again is expected to be harmless because the call is in place.
if classifier then
  classifier = classifier:cuda()
end
criterion = criterion:cuda()
collectgarbage()