From de68fbd7538a33ea7cb2e79d93f18a04987d7d65 Mon Sep 17 00:00:00 2001
From: David Stap <dd.stap@gmail.com>
Date: Thu, 7 Mar 2019 16:43:23 +0100
Subject: [PATCH 1/3] Python 3.6+ and Pytorch 1.0+

---
 code/GlobalAttention.py |  6 +++---
 code/datasets.py        |  6 +++---
 code/main.py            |  6 +++---
 code/miscc/config.py    |  4 ++--
 code/miscc/losses.py    |  4 ++--
 code/miscc/utils.py     | 19 ++++++++++++-------
 code/model.py           | 18 +++++++++++++++---
 code/pretrain_DAMSM.py  |  2 +-
 code/trainer.py         | 28 +++++++++++++++++-----------
 eval/eval.py            | 21 +++++++++++----------
 eval/miscc/utils.py     |  6 +++---
 eval/model.py           | 12 ++++++------
 12 files changed, 78 insertions(+), 54 deletions(-)

diff --git a/code/GlobalAttention.py b/code/GlobalAttention.py
index 501fb720..c5322172 100644
--- a/code/GlobalAttention.py
+++ b/code/GlobalAttention.py
@@ -48,7 +48,7 @@ def func_attention(query, context, gamma1):
     attn = torch.bmm(contextT, query) # Eq. (7) in AttnGAN paper
     # --> batch*sourceL x queryL
     attn = attn.view(batch_size*sourceL, queryL)
-    attn = nn.Softmax()(attn)  # Eq. (8)
+    attn = nn.Softmax(dim=1)(attn)  # Eq. (8)
 
     # --> batch x sourceL x queryL
     attn = attn.view(batch_size, sourceL, queryL)
@@ -57,7 +57,7 @@ def func_attention(query, context, gamma1):
     attn = attn.view(batch_size*queryL, sourceL)
     #  Eq. (9)
     attn = attn * gamma1
-    attn = nn.Softmax()(attn)
+    attn = nn.Softmax(dim=1)(attn)
     attn = attn.view(batch_size, queryL, sourceL)
     # --> batch x sourceL x queryL
     attnT = torch.transpose(attn, 1, 2).contiguous()
@@ -73,7 +73,7 @@ class GlobalAttentionGeneral(nn.Module):
     def __init__(self, idf, cdf):
         super(GlobalAttentionGeneral, self).__init__()
         self.conv_context = conv1x1(cdf, idf)
-        self.sm = nn.Softmax()
+        self.sm = nn.Softmax(dim=1)
         self.mask = None
 
     def applyMask(self, mask):
diff --git a/code/datasets.py b/code/datasets.py
index 24ffdc4a..a6cb1f8b 100644
--- a/code/datasets.py
+++ b/code/datasets.py
@@ -80,7 +80,7 @@ def get_imgs(img_path, imsize, bbox=None,
         for i in range(cfg.TREE.BRANCH_NUM):
             # print(imsize[i])
             if i < (cfg.TREE.BRANCH_NUM - 1):
-                re_img = transforms.Scale(imsize[i])(img)
+                re_img = transforms.Resize(imsize[i])(img)
             else:
                 re_img = img
             ret.append(normalize(re_img))
@@ -133,7 +133,7 @@ def load_bbox(self):
         #
         filename_bbox = {img_file[:-4]: [] for img_file in filenames}
         numImgs = len(filenames)
-        for i in xrange(0, numImgs):
+        for i in range(0, numImgs):
             # bbox = [x-left, y-top, width, height]
             bbox = df_bounding_boxes.iloc[i][1:].tolist()
 
@@ -251,7 +251,7 @@ def load_text_data(self, data_dir, split):
     def load_class_id(self, data_dir, total_num):
         if os.path.isfile(data_dir + '/class_info.pickle'):
             with open(data_dir + '/class_info.pickle', 'rb') as f:
-                class_id = pickle.load(f)
+                class_id = pickle.load(f, encoding="bytes")
         else:
             class_id = np.arange(total_num)
         return class_id
diff --git a/code/main.py b/code/main.py
index 934e7764..783ee730 100644
--- a/code/main.py
+++ b/code/main.py
@@ -39,14 +39,14 @@ def gen_example(wordtoix, algo):
     filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR)
     data_dic = {}
     with open(filepath, "r") as f:
-        filenames = f.read().decode('utf8').split('\n')
+        filenames = f.read().split('\n')
         for name in filenames:
             if len(name) == 0:
                 continue
             filepath = '%s/%s.txt' % (cfg.DATA_DIR, name)
             with open(filepath, "r") as f:
                 print('Load from:', name)
-                sentences = f.read().decode('utf8').split('\n')
+                sentences = f.read().split('\n')
                 # a list of indices for a sentence
                 captions = []
                 cap_lens = []
@@ -121,7 +121,7 @@ def gen_example(wordtoix, algo):
     # Get data loader
     imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
     image_transform = transforms.Compose([
-        transforms.Scale(int(imsize * 76 / 64)),
+        transforms.Resize(int(imsize * 76 / 64)),
         transforms.RandomCrop(imsize),
         transforms.RandomHorizontalFlip()])
     dataset = TextDataset(cfg.DATA_DIR, split_dir,
diff --git a/code/miscc/config.py b/code/miscc/config.py
index 797319ba..e10990f0 100644
--- a/code/miscc/config.py
+++ b/code/miscc/config.py
@@ -70,9 +70,9 @@ def _merge_a_into_b(a, b):
     if type(a) is not edict:
         return
 
-    for k, v in a.iteritems():
+    for k, v in a.items():
         # a must specify keys that are in b
-        if not b.has_key(k):
+        if k not in b:
             raise KeyError('{} is not a valid config key'.format(k))
 
         # the types must match, too
diff --git a/code/miscc/losses.py b/code/miscc/losses.py
index b15612bf..90efd5c2 100644
--- a/code/miscc/losses.py
+++ b/code/miscc/losses.py
@@ -181,7 +181,7 @@ def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
             g_loss = cond_errG
         errG_total += g_loss
         # err_img = errG_total.data[0]
-        logs += 'g_loss%d: %.2f ' % (i, g_loss.data[0])
+        logs += 'g_loss%d: %.2f ' % (i, g_loss.item())
 
         # Ranking loss
         if i == (numDs - 1):
@@ -202,7 +202,7 @@ def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
             # err_sent = err_sent + s_loss.data[0]
 
             errG_total += w_loss + s_loss
-            logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.data[0], s_loss.data[0])
+            logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.item(), s_loss.item())
     return errG_total, logs
 
 
diff --git a/code/miscc/utils.py b/code/miscc/utils.py
index f131a365..04d16ebe 100644
--- a/code/miscc/utils.py
+++ b/code/miscc/utils.py
@@ -75,7 +75,8 @@ def build_super_images(real_imgs, captions, ixtoword,
 
 
     real_imgs = \
-        nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
+        nn.functional.interpolate(real_imgs,size=(vis_size, vis_size),
+                                  mode='bilinear', align_corners=False)
     # [-1, 1] --> [0, 1]
     real_imgs.add_(1).div_(2).mul_(255)
     real_imgs = real_imgs.data.numpy()
@@ -86,7 +87,8 @@ def build_super_images(real_imgs, captions, ixtoword,
     post_pad = np.zeros([pad_sze[1], pad_sze[2], 3])
     if lr_imgs is not None:
         lr_imgs = \
-            nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(lr_imgs)
+            nn.functional.interpolate(lr_imgs,size=(vis_size, vis_size),
+                                  mode='bilinear', align_corners=False)
         # [-1, 1] --> [0, 1]
         lr_imgs.add_(1).div_(2).mul_(255)
         lr_imgs = lr_imgs.data.numpy()
@@ -129,7 +131,8 @@ def build_super_images(real_imgs, captions, ixtoword,
             if (vis_size // att_sze) > 1:
                 one_map = \
                     skimage.transform.pyramid_expand(one_map, sigma=20,
-                                                     upscale=vis_size // att_sze)
+                                                     upscale=vis_size // att_sze,
+                                                     multichannel=True)
             row_beforeNorm.append(one_map)
             minV = one_map.min()
             maxV = one_map.max()
@@ -185,7 +188,8 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
                            dtype=np.uint8)
 
     real_imgs = \
-        nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
+        nn.functional.interpolate(real_imgs,size=(vis_size, vis_size),
+                                    mode='bilinear', align_corners=False)
     # [-1, 1] --> [0, 1]
     real_imgs.add_(1).div_(2).mul_(255)
     real_imgs = real_imgs.data.numpy()
@@ -228,7 +232,8 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
             if (vis_size // att_sze) > 1:
                 one_map = \
                     skimage.transform.pyramid_expand(one_map, sigma=20,
-                                                     upscale=vis_size // att_sze)
+                                                     upscale=vis_size // att_sze,
+                                                     multichannel=True)
             minV = one_map.min()
             maxV = one_map.max()
             one_map = (one_map - minV) / (maxV - minV)
@@ -286,12 +291,12 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
 def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
-        nn.init.orthogonal(m.weight.data, 1.0)
+        nn.init.orthogonal_(m.weight.data, 1.0)
     elif classname.find('BatchNorm') != -1:
         m.weight.data.normal_(1.0, 0.02)
         m.bias.data.fill_(0)
     elif classname.find('Linear') != -1:
-        nn.init.orthogonal(m.weight.data, 1.0)
+        nn.init.orthogonal_(m.weight.data, 1.0)
         if m.bias is not None:
             m.bias.data.fill_(0.0)
 
diff --git a/code/model.py b/code/model.py
index 09bde34e..9bc514b9 100644
--- a/code/model.py
+++ b/code/model.py
@@ -20,7 +20,19 @@ def forward(self, x):
         nc = x.size(1)
         assert nc % 2 == 0, 'channels dont divide 2!'
         nc = int(nc/2)
-        return x[:, :nc] * F.sigmoid(x[:, nc:])
+        return x[:, :nc] * torch.sigmoid(x[:, nc:])
+
+class Interpolate(nn.Module):
+    def __init__(self, scale_factor, mode, size=None):
+        super(Interpolate, self).__init__()
+        self.interp = nn.functional.interpolate
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.size = size
+
+    def forward(self, x):
+        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, size=self.size)
+        return x
 
 
 def conv1x1(in_planes, out_planes, bias=False):
@@ -38,7 +50,7 @@ def conv3x3(in_planes, out_planes):
 # Upsale the spatial size by a factor of 2
 def upBlock(in_planes, out_planes):
     block = nn.Sequential(
-        nn.Upsample(scale_factor=2, mode='nearest'),
+        Interpolate(scale_factor=2, mode='nearest'),
         conv3x3(in_planes, out_planes * 2),
         nn.BatchNorm2d(out_planes * 2),
         GLU())
@@ -207,7 +219,7 @@ def init_trainable_weights(self):
     def forward(self, x):
         features = None
         # --> fixed-size input: batch x 3 x 299 x 299
-        x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
+        x = nn.functional.interpolate(x,size=(299, 299), mode='bilinear', align_corners=False)
         # 299 x 299 x 3
         x = self.Conv2d_1a_3x3(x)
         # 149 x 149 x 32
diff --git a/code/pretrain_DAMSM.py b/code/pretrain_DAMSM.py
index 5f8b0ff9..1bce8167 100644
--- a/code/pretrain_DAMSM.py
+++ b/code/pretrain_DAMSM.py
@@ -235,7 +235,7 @@ def build_models():
     imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1))
     batch_size = cfg.TRAIN.BATCH_SIZE
     image_transform = transforms.Compose([
-        transforms.Scale(int(imsize * 76 / 64)),
+        transforms.Resize(int(imsize * 76 / 64)),
         transforms.RandomCrop(imsize),
         transforms.RandomHorizontalFlip()])
     dataset = TextDataset(cfg.DATA_DIR, 'train',
diff --git a/code/trainer.py b/code/trainer.py
index a6d4180f..1bbd2367 100644
--- a/code/trainer.py
+++ b/code/trainer.py
@@ -274,7 +274,7 @@ def train(self):
                     errD.backward()
                     optimizersD[i].step()
                     errD_total += errD
-                    D_logs += 'errD%d: %.2f ' % (i, errD.data[0])
+                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
 
                 #######################################################
                 # (4) Update G network: maximize log(D(G(z)))
@@ -291,7 +291,7 @@ def train(self):
                                    words_embs, sent_emb, match_labels, cap_lens, class_ids)
                 kl_loss = KL_loss(mu, logvar)
                 errG_total += kl_loss
-                G_logs += 'kl_loss: %.2f ' % kl_loss.data[0]
+                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                 # backward and update parameters
                 errG_total.backward()
                 optimizerG.step()
@@ -318,7 +318,7 @@ def train(self):
             print('''[%d/%d][%d]
                   Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                   % (epoch, self.max_epoch, self.num_batches,
-                     errD_total.data[0], errG_total.data[0],
+                     errD_total.item(), errG_total.item(),
                      end_t - start_t))
 
             if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
@@ -370,8 +370,10 @@ def sampling(self, split_dir):
 
             batch_size = self.batch_size
             nz = cfg.GAN.Z_DIM
-            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
-            noise = noise.cuda()
+
+            with torch.no_grad():
+                noise = Variable(torch.FloatTensor(batch_size, nz))
+                noise = noise.cuda()
 
             model_dir = cfg.TRAIN.NET_G
             state_dict = \
@@ -463,14 +465,18 @@ def gen_example(self, data_dic):
 
                 batch_size = captions.shape[0]
                 nz = cfg.GAN.Z_DIM
-                captions = Variable(torch.from_numpy(captions), volatile=True)
-                cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)
 
-                captions = captions.cuda()
-                cap_lens = cap_lens.cuda()
+                with torch.no_grad():
+                    captions = Variable(torch.from_numpy(captions))
+                    cap_lens = Variable(torch.from_numpy(cap_lens))
+
+                    captions = captions.cuda()
+                    cap_lens = cap_lens.cuda()
+
                 for i in range(1):  # 16
-                    noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
-                    noise = noise.cuda()
+                    with torch.no_grad():
+                        noise = Variable(torch.FloatTensor(batch_size, nz))
+                        noise = noise.cuda()
                     #######################################################
                     # (1) Extract text embeddings
                     ######################################################
diff --git a/eval/eval.py b/eval/eval.py
index 48005f73..dc42f39c 100644
--- a/eval/eval.py
+++ b/eval/eval.py
@@ -54,16 +54,17 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi
     batch_size = captions.shape[0]
 
     nz = cfg.GAN.Z_DIM
-    captions = Variable(torch.from_numpy(captions), volatile=True)
-    cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)
-    noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
+    with torch.no_grad():
+        captions = Variable(torch.from_numpy(captions))
+        cap_lens = Variable(torch.from_numpy(cap_lens))
+        noise = Variable(torch.FloatTensor(batch_size, nz))
 
     if cfg.CUDA:
         captions = captions.cuda()
         cap_lens = cap_lens.cuda()
         noise = noise.cuda()
 
-    
+
 
     #######################################################
     # (1) Extract text embeddings
@@ -71,7 +72,7 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi
     hidden = text_encoder.init_hidden(batch_size)
     words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
     mask = (captions == 0)
-        
+
 
     #######################################################
     # (2) Generate fake images
@@ -131,7 +132,7 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi
                         im = fake_imgs[k + 1].detach().cpu()
                     else:
                         im = fake_imgs[0].detach().cpu()
-                            
+
                     attn_maps = attention_maps[k]
                     att_sze = attn_maps.size(2)
 
@@ -152,7 +153,7 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi
                         urls.append(full_path % blob_name)
         if copies == 2:
             break
-    
+
     #print(len(urls), urls)
     return urls
 
@@ -223,7 +224,7 @@ def eval(caption):
 
 if __name__ == "__main__":
     caption = "the bird has a yellow crown and a black eyering that is round"
-    
+
     # load configuration
     #cfg_from_file('eval_bird.yml')
     # load word dictionaries
@@ -232,9 +233,9 @@ def eval(caption):
     text_encoder, netG = models(len(wordtoix))
     # load blob service
     blob_service = BlockBlobService(account_name='attgan', account_key='[REDACTED]')
-    
+
     t0 = time.time()
     urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service)
     t1 = time.time()
     print(t1-t0)
-    print(urls)
\ No newline at end of file
+    print(urls)
diff --git a/eval/miscc/utils.py b/eval/miscc/utils.py
index 13fc4739..1993b4c0 100644
--- a/eval/miscc/utils.py
+++ b/eval/miscc/utils.py
@@ -58,7 +58,7 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
                            dtype=np.uint8)
 
     real_imgs = \
-        nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
+        nn.functional.interpolate(real_imgs,size=(vis_size, vis_size), mode='bilinear')
     # [-1, 1] --> [0, 1]
     real_imgs.add_(1).div_(2).mul_(255)
     real_imgs = real_imgs.data.numpy()
@@ -159,12 +159,12 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
 def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
-        nn.init.orthogonal(m.weight.data, 1.0)
+        nn.init.orthogonal_(m.weight.data, 1.0)
     elif classname.find('BatchNorm') != -1:
         m.weight.data.normal_(1.0, 0.02)
         m.bias.data.fill_(0)
     elif classname.find('Linear') != -1:
-        nn.init.orthogonal(m.weight.data, 1.0)
+        nn.init.orthogonal_(m.weight.data, 1.0)
         if m.bias is not None:
             m.bias.data.fill_(0.0)
 
diff --git a/eval/model.py b/eval/model.py
index 6d37ab3d..183769c1 100644
--- a/eval/model.py
+++ b/eval/model.py
@@ -26,7 +26,7 @@ def __init__(self, ntoken, ninput=300, drop_prob=0.5,
         self.drop_prob = drop_prob  # probability of an element to be zeroed
         self.nlayers = nlayers  # Number of recurrent layers
         self.bidirectional = bidirectional
-        
+
         if bidirectional:
             self.num_directions = 2
         else:
@@ -113,7 +113,7 @@ def __init__(self):
         nef = cfg.TEXT.EMBEDDING_DIM
         ncf = cfg.GAN.CONDITION_DIM
 
-        
+
         self.ca_net = CA_NET()
 
         if cfg.TREE.BRANCH_NUM > 0:
@@ -170,7 +170,7 @@ def __init__(self):
 
         self.t_dim = cfg.TEXT.EMBEDDING_DIM
         self.c_dim = cfg.GAN.CONDITION_DIM
-        
+
         self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True)
         self.relu = GLU()
 
@@ -204,7 +204,7 @@ def forward(self, x):
         nc = x.size(1)
         assert nc % 2 == 0, 'channels dont divide 2!'
         nc = int(nc/2)
-        return x[:, :nc] * F.sigmoid(x[:, nc:])
+        return x[:, :nc] * torch.sigmoid(x[:, nc:])
 
 
 def conv1x1(in_planes, out_planes, bias=False):
@@ -222,7 +222,7 @@ def conv3x3(in_planes, out_planes):
 # Upsale the spatial size by a factor of 2
 def upBlock(in_planes, out_planes):
     block = nn.Sequential(
-        nn.Upsample(scale_factor=2, mode='nearest'),
+        nn.functional.interpolate(scale_factor=2, mode='nearest'),
         conv3x3(in_planes, out_planes * 2),
         nn.BatchNorm2d(out_planes * 2),
         GLU())
@@ -304,7 +304,7 @@ def init_trainable_weights(self):
     def forward(self, x):
         features = None
         # --> fixed-size input: batch x 3 x 299 x 299
-        x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
+        x = nn.functional.interpolate(x,size=(299, 299), mode='bilinear')
         # 299 x 299 x 3
         x = self.Conv2d_1a_3x3(x)
         # 149 x 149 x 32

From 676657af666109c1f364b3eef4937519d1224f75 Mon Sep 17 00:00:00 2001
From: David Stap <dd.stap@gmail.com>
Date: Thu, 7 Mar 2019 17:01:40 +0100
Subject: [PATCH 2/3] Update README.md

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index aa75fe10..87f97a1f 100644
--- a/README.md
+++ b/README.md
@@ -7,9 +7,9 @@ with Attentional Generative Adversarial Networks](http://openaccess.thecvf.com/c
 
 
 ### Dependencies
-python 2.7
+python 3.6+
 
-Pytorch
+Pytorch 1.0+
 
 In addition, please add the project folder to PYTHONPATH and `pip install` the following packages:
 - `python-dateutil`

From 78f9e1d10e7b89aa38b3d5847adc2d8849719844 Mon Sep 17 00:00:00 2001
From: David Stap <dd.stap@gmail.com>
Date: Tue, 19 Mar 2019 13:25:48 +0100
Subject: [PATCH 3/3] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 87f97a1f..615421c0 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# AttnGAN
+# AttnGAN (Python 3, Pytorch 1.0)
 
 Pytorch implementation for reproducing AttnGAN results in the paper [AttnGAN: Fine-Grained Text to Image Generation
 with Attentional Generative Adversarial Networks](http://openaccess.thecvf.com/content_cvpr_2018/papers/Xu_AttnGAN_Fine-Grained_Text_CVPR_2018_paper.pdf) by Tao Xu, Pengchuan Zhang, Qiuyuan Huang, Han Zhang, Zhe Gan, Xiaolei Huang, Xiaodong He. (This work was performed when Tao was an intern with Microsoft Research).