From 4379ae0eb9f98558b292bb221489aef5e2a5cce9 Mon Sep 17 00:00:00 2001 From: KCool Date: Wed, 8 Aug 2018 13:55:55 +0200 Subject: [PATCH] Add CPU option Added "if CUDA" to all cuda() calls --- code/trainer.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/code/trainer.py b/code/trainer.py index a6d4180f..874ac15c 100644 --- a/code/trainer.py +++ b/code/trainer.py @@ -357,7 +357,8 @@ def sampling(self, split_dir): else: netG = G_NET() netG.apply(weights_init) - netG.cuda() + if cfg.CUDA: + netG.cuda() netG.eval() # text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) @@ -365,13 +366,15 @@ def sampling(self, split_dir): torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) text_encoder.load_state_dict(state_dict) print('Load text encoder from:', cfg.TRAIN.NET_E) - text_encoder = text_encoder.cuda() + if cfg.CUDA: + text_encoder = text_encoder.cuda() text_encoder.eval() batch_size = self.batch_size nz = cfg.GAN.Z_DIM noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) - noise = noise.cuda() + if cfg.CUDA: + noise = noise.cuda() model_dir = cfg.TRAIN.NET_G state_dict = \ @@ -440,7 +443,8 @@ def gen_example(self, data_dic): torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) text_encoder.load_state_dict(state_dict) print('Load text encoder from:', cfg.TRAIN.NET_E) - text_encoder = text_encoder.cuda() + if cfg.CUDA: + text_encoder = text_encoder.cuda() text_encoder.eval() # the path to save generated images @@ -454,7 +458,8 @@ def gen_example(self, data_dic): torch.load(model_dir, map_location=lambda storage, loc: storage) netG.load_state_dict(state_dict) print('Load G from: ', model_dir) - netG.cuda() + if cfg.CUDA: + netG.cuda() netG.eval() for key in data_dic: save_dir = '%s/%s' % (s_tmp, key) @@ -466,11 +471,13 @@ def gen_example(self, data_dic): captions = Variable(torch.from_numpy(captions), volatile=True) cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True) - captions = captions.cuda() - cap_lens = cap_lens.cuda() + if cfg.CUDA: + captions = captions.cuda() + cap_lens = cap_lens.cuda() for i in range(1): # 16 noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) - noise = noise.cuda() + if cfg.CUDA: + noise = noise.cuda() ####################################################### # (1) Extract text embeddings ######################################################