forked from eriche2016/seq2seq-1
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gpu2cpu.lua
81 lines (69 loc) · 1.99 KB
/
gpu2cpu.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
--[[
Convert GPU model to CPU model
Woohyun Kim ([email protected])
--]]
-- torch7
require('torch')
require('nn')
require('nngraph')
require('optim')
require('lfs')
-- util
require('util.parser')
require('util.filereader')
require('util.wordindexer')
require('util.inputloader')
require('util.batcher')
require('util.tensorbatcher')
require('util.Squeeze')
-- model
require('model.RNN')
require('model.LSTM')
require('model.GRU')
require('model.HighwayMLP')
require('model.TDNN')
require('model.LSTMTDNN')
-- criterion
require('model.HSMClass')
require('model.HLogSoftMax')
-- network & optimizer
require('network')
require('optimizer')
--BatchLoader = require 'util.BatchLoaderUnk'
model_utils = require 'util.model_utils'
-- Command-line interface. `cmd` and `opt` stay global, as in the original
-- script, since code loaded via require may read them.
cmd = torch.CmdLine()
cmd:text()
cmd:text('Convert GPU model to CPU model')
cmd:text()
cmd:text('Options')
-- Each entry is { flag, default, help }, registered in this order.
local option_specs = {
  { '-checkpoint', 'cv/checkpoint.t7', 'GPU model to convert' }, -- model
  { '-gpuid', -1, 'GPU device' },                                -- GPU
}
for _, spec in ipairs(option_specs) do
  cmd:option(spec[1], spec[2], spec[3])
end
opt = cmd:parse(arg)
-- GPU setup. The checkpoint being converted holds CUDA objects, and torch.load
-- can only deserialize them once cutorch/cunn have registered their classes.
if opt.gpuid >= 0 then
  require 'cutorch'
  require 'cunn'
  -- -gpuid is 0-based while cutorch numbers devices from 1.
  cutorch.setDevice(opt.gpuid + 1)
else
  -- Best effort: with the default -gpuid -1 the CUDA packages would otherwise
  -- never be loaded and torch.load would fail on a GPU checkpoint. On a
  -- machine without CUDA these requires fail and are deliberately ignored.
  if pcall(require, 'cutorch') then
    pcall(require, 'cunn')
  end
end
-- Load the GPU checkpoint and seed the CPU checkpoint with its options
-- (switched to CPU mode) and its indexer.
print('loading ' .. opt.checkpoint .. ' for converting')
local checkpoint = torch.load(opt.checkpoint)
local checkpoint_4cpu = {
  opt = checkpoint.opt,
  indexer = checkpoint.indexer,
}
-- NOTE: this is the same table as checkpoint.opt, so the loaded checkpoint's
-- options change too; harmless, since only checkpoint_4cpu is saved below.
checkpoint_4cpu.opt.gpuid = -1
-- Rebuild the network for CPU use: a fresh Network built from the CPU options,
-- with the trained modules and initial states converted to double precision.
local network = Network(checkpoint_4cpu.opt) -- create a new network
network.opt = checkpoint_4cpu.opt
-- nn.Module:double() converts the module's parameters and returns the module.
network.rnn = checkpoint.network.rnn:double()
network.criterion = checkpoint.network.criterion:double()

-- Returns a new table mapping each key of `states` to an independent
-- DoubleTensor copy of its tensor (clone() guarantees a copy even when the
-- input is already double).
local function to_double_copies(states)
  local copies = {}
  for k, v in pairs(states) do copies[k] = v:double():clone() end
  return copies
end

network.init_state = to_double_copies(checkpoint.network.init_state)
-- Previously this table was converted but never attached, so the converted
-- checkpoint silently lacked it.
network.init_state_global = to_double_copies(checkpoint.network.init_state_global)
checkpoint_4cpu.network = network
-- Output path: the input path with its trailing extension replaced by
-- ".4cpu.t7" (cv/checkpoint.t7 -> cv/checkpoint.4cpu.t7). Only a final
-- ".ext" after the last '/' is stripped; a path without an extension simply
-- gets ".4cpu.t7" appended. (Using the extension itself as a gsub pattern
-- rewrote every occurrence of it in the path, and crashed when there was none.)
local savefile_stem = (opt.checkpoint:gsub('%.[^./]*$', ''))
local savefile = savefile_stem .. '.4cpu.t7'
print('saving checkpoint to ' .. savefile)
torch.save(savefile, checkpoint_4cpu)
print('saved.')