-
Notifications
You must be signed in to change notification settings - Fork 16
/
generate.py
64 lines (53 loc) · 2.08 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torchvision.utils as vutils
import time
from draw_model import DRAWModel
# Command-line interface.
#   -load_path  : path of the checkpoint to load (required).
#   -num_output : how many images to generate; the still image is laid out
#                 with int(sqrt(num_output)) images per row.
#   -t          : number of glimpses (time steps); when omitted, the 'T'
#                 value stored in the checkpoint's params is used.
parser = argparse.ArgumentParser()
parser.add_argument('-load_path', required=True, help='Checkpoint to load path from')
# type=int makes argparse reject non-integer values immediately with a usage
# error, instead of failing later inside the int(...) conversions.
parser.add_argument('-num_output', default=36, type=int, help='Number of generated outputs')
parser.add_argument('-t', default=None, type=int, help='Number of glimpses.')
args = parser.parse_args()
# Select the device first so the checkpoint can be mapped onto it: tensors
# saved from a GPU run cannot be deserialized on a CPU-only machine unless
# torch.load is given a map_location.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the checkpoint file.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from trusted sources.
state_dict = torch.load(args.load_path, map_location=device)
# The checkpoint holds the 'params' dict used to construct DRAWModel and the
# trained weights under 'model'.
params = state_dict['params']
# Set the number of glimpses.
# Best to just use the same value which was used for training.
params['T'] = int(args.t) if args.t else params['T']
# Build the model from the checkpoint's params and restore the trained weights.
model = DRAWModel(params).to(device)
model.load_state_dict(state_dict['model'])
print('\n')
print(model)
# Sample a batch of images from the model, timing the generation pass.
# Gradients are never needed for sampling, hence torch.no_grad().
start_time = time.time()
print('*' * 25)
print("Generating Image...")
with torch.no_grad():
    x = model.generate(int(args.num_output))
time_elapsed = time.time() - start_time
print('\nDONE!')
print(f'Time taken to generate image: {time_elapsed:.2f}s')
print('\nSaving generated image...')

# The final element of x is rendered as a square-ish grid: grid_cols images
# per row, with a figure of 2 inches per grid column on each side.
grid_cols = int(np.sqrt(int(args.num_output)))
fig_side = grid_cols * 2
fig = plt.figure(figsize=(fig_side, fig_side))
plt.axis("off")
final_grid = vutils.make_grid(
    x[-1], nrow=grid_cols, padding=1, normalize=True, pad_value=1
).cpu()
# imshow expects channels last, so move the channel axis to the end.
plt.imshow(np.transpose(final_grid, (1, 2, 0)))
plt.savefig("Generated_Image")
plt.close('all')
# Create an animation of the generation process, one frame per element of x.
# Each frame is assumed to be a channel-first image tensor (the transpose below
# requires exactly three dims). .cpu() is applied so the NumPy conversion also
# works if the model returns GPU tensors; it is a no-op on CPU tensors.
grid_cols = int(np.sqrt(int(args.num_output)))
fig = plt.figure(figsize=(grid_cols * 2, grid_cols * 2))
plt.axis("off")
ims = [[plt.imshow(np.transpose(frame.cpu(), (1, 2, 0)), animated=True)] for frame in x]
# interval: milliseconds between frames; repeat_delay: milliseconds of pause
# before the animation loops.
anim = animation.ArtistAnimation(fig, ims, interval=200, repeat_delay=2000, blit=True)
# The 'imagemagick' writer requires ImageMagick to be installed on the system.
anim.save('draw_generate.gif', dpi=100, writer='imagemagick')
print('DONE!')
print('-'*50)
plt.show()