generate_plots.lua
require 'audio'
require 'gnuplot'
local path = require 'pl.path'  -- Penlight's path module (assumed installed); provides the path.mkdir used below

-- Parse command-line options
local cmd = torch.CmdLine()
cmd:text('generate_plots.lua - plots the loss and gradNorm curves for a given session')
cmd:text('')
cmd:text('Session:')
cmd:option('-session','default','The name of the session for which to generate plots')
cmd:text('')
local args = cmd:parse(arg)
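-- Example invocation (the session name 'my_run' is hypothetical):
--   th generate_plots.lua -session my_run
-- This expects sessions/my_run/session.t7, losses.t7 and gradNorms.t7 to exist,
-- presumably written by the training script.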
-- Session directory; create a subdirectory for the generated plots
local session_path = 'sessions/'..args.session
path.mkdir(session_path..'/plots')

-- Load the saved session metadata and the recorded loss / gradient-norm histories
local session = torch.load(session_path..'/session.t7')
local losses = torch.load(session_path..'/losses.t7')
local grads = torch.load(session_path..'/gradNorms.t7')

-- Load one training file to recover how many sequence chunks (and hence how
-- many recorded loss values) each minibatch contributes
local audio_data_path = 'datasets/'..session.dataset..'/data'
local aud, sample_rate = audio.load(audio_data_path..'/p0001.wav')
local n_tsteps = math.floor((aud:size(1) - session.big_frame_size) / session.seq_len)
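-- Worked example with hypothetical values: for an 80000-sample file with
-- big_frame_size = 8 and seq_len = 512, n_tsteps = floor((80000 - 8) / 512) = 156,
-- i.e. 156 loss values are recorded per minibatch.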
print(#losses..' iterations')

-- Copy the recorded values from plain Lua tables into Tensors for plotting
local lossesTensor = torch.Tensor(#losses)
for i = 1, #losses do
  lossesTensor[i] = losses[i]
end
local gradsTensor = torch.Tensor(#grads)
for i = 1, #grads do
  gradsTensor[i] = grads[i]
end
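-- Note: torch.Tensor can also be constructed directly from a flat table of
-- numbers (e.g. torch.Tensor(losses)), which would replace the copy loops above.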
print('Plotting loss curve ...')

-- View the per-step losses as one row per minibatch (the view shares storage
-- with lossesTensor). Clamp the losses to the largest per-minibatch maximum
-- seen after the first minibatch, so the large initial losses do not dominate
-- the plot, then compute per-minibatch statistics.
local loss_by_batch = lossesTensor:view(-1, n_tsteps)
lossesTensor:clamp(0, loss_by_batch:max(2)[{{2,-1}}]:max())
local loss_max = loss_by_batch:max(2)
local loss_min = loss_by_batch:min(2)
local loss_mean = loss_by_batch:mean(2)
-- Write the loss curve (per-minibatch min / max / mean) to a PDF
gnuplot.pdffigure(session_path..'/plots/loss_curve.pdf')
gnuplot.raw('set size noratio')  -- default rectangular aspect ('rectangle' is not a valid gnuplot size option)
gnuplot.raw('set xlabel "minibatches"')
gnuplot.raw('set ylabel "NLL (bits)"')
gnuplot.plot({'min', loss_min, '-'}, {'max', loss_max, '-'}, {'mean', loss_mean, '-'})
gnuplot.plotflush()
gnuplot.close()
print('Plotting grad curve ...')

-- Write the gradient-norm curve (one value per iteration) to a PDF
gnuplot.pdffigure(session_path..'/plots/grad_curve.pdf')
gnuplot.raw('set size noratio')
gnuplot.raw('set xlabel "iterations"')
gnuplot.raw('set ylabel "norm(dparam)"')
gnuplot.plot({gradsTensor, '-'})
gnuplot.plotflush()
gnuplot.close()
print('Done!')