diff --git a/README.md b/README.md
index aa75fe10..615421c0 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# AttnGAN
+# AttnGAN (Python 3, Pytorch 1.0)
 
 Pytorch implementation for reproducing AttnGAN results in the paper [AttnGAN: Fine-Grained Text to Image Generation
 with Attentional Generative Adversarial Networks](http://openaccess.thecvf.com/content_cvpr_2018/papers/Xu_AttnGAN_Fine-Grained_Text_CVPR_2018_paper.pdf) by Tao Xu, Pengchuan Zhang, Qiuyuan Huang, Han Zhang, Zhe Gan, Xiaolei Huang, Xiaodong He. (This work was performed when Tao was an intern with Microsoft Research).
@@ -7,9 +7,9 @@ with Attentional Generative Adversarial Networks](http://openaccess.thecvf.com/c
 
 ### Dependencies
-python 2.7
+python 3.6+
 
-Pytorch
+Pytorch 1.0+
 
 In addition, please add the project folder to PYTHONPATH and `pip install` the following packages:
 - `python-dateutil`
diff --git a/code/GlobalAttention.py b/code/GlobalAttention.py
index 501fb720..c5322172 100644
--- a/code/GlobalAttention.py
+++ b/code/GlobalAttention.py
@@ -48,7 +48,7 @@ def func_attention(query, context, gamma1):
     attn = torch.bmm(contextT, query) # Eq. (7) in AttnGAN paper
     # --> batch*sourceL x queryL
     attn = attn.view(batch_size*sourceL, queryL)
-    attn = nn.Softmax()(attn)  # Eq. (8)
+    attn = nn.Softmax(dim=1)(attn)  # Eq. (8)
 
     # --> batch x sourceL x queryL
     attn = attn.view(batch_size, sourceL, queryL)
@@ -57,7 +57,7 @@ def func_attention(query, context, gamma1):
     attn = attn.view(batch_size*queryL, sourceL)
     #  Eq. (9)
     attn = attn * gamma1
-    attn = nn.Softmax()(attn)
+    attn = nn.Softmax(dim=1)(attn)
     attn = attn.view(batch_size, queryL, sourceL)
     # --> batch x sourceL x queryL
     attnT = torch.transpose(attn, 1, 2).contiguous()
@@ -73,7 +73,7 @@ class GlobalAttentionGeneral(nn.Module):
     def __init__(self, idf, cdf):
         super(GlobalAttentionGeneral, self).__init__()
         self.conv_context = conv1x1(cdf, idf)
-        self.sm = nn.Softmax()
+        self.sm = nn.Softmax(dim=1)
         self.mask = None
 
     def applyMask(self, mask):
diff --git a/code/datasets.py b/code/datasets.py
index 24ffdc4a..a6cb1f8b 100644
--- a/code/datasets.py
+++ b/code/datasets.py
@@ -80,7 +80,7 @@ def get_imgs(img_path, imsize, bbox=None,
     for i in range(cfg.TREE.BRANCH_NUM):
         # print(imsize[i])
         if i < (cfg.TREE.BRANCH_NUM - 1):
-            re_img = transforms.Scale(imsize[i])(img)
+            re_img = transforms.Resize(imsize[i])(img)
         else:
             re_img = img
         ret.append(normalize(re_img))
@@ -133,7 +133,7 @@ def load_bbox(self):
         #
         filename_bbox = {img_file[:-4]: [] for img_file in filenames}
         numImgs = len(filenames)
-        for i in xrange(0, numImgs):
+        for i in range(0, numImgs):
             # bbox = [x-left, y-top, width, height]
             bbox = df_bounding_boxes.iloc[i][1:].tolist()
 
@@ -251,7 +251,7 @@ def load_text_data(self, data_dir, split):
     def load_class_id(self, data_dir, total_num):
         if os.path.isfile(data_dir + '/class_info.pickle'):
             with open(data_dir + '/class_info.pickle', 'rb') as f:
-                class_id = pickle.load(f)
+                class_id = pickle.load(f, encoding="bytes")
         else:
             class_id = np.arange(total_num)
         return class_id
diff --git a/code/main.py b/code/main.py
index 934e7764..783ee730 100644
--- a/code/main.py
+++ b/code/main.py
@@ -39,14 +39,14 @@ def gen_example(wordtoix, algo):
     filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR)
     data_dic = {}
     with open(filepath, "r") as f:
-        filenames = f.read().decode('utf8').split('\n')
+        filenames = f.read().split('\n')
         for name in filenames:
             if len(name) == 0:
                 continue
             filepath = '%s/%s.txt' % (cfg.DATA_DIR, name)
             with open(filepath, "r") as f:
                 print('Load from:', name)
-                sentences = f.read().decode('utf8').split('\n')
+                sentences = f.read().split('\n')
                 # a list of indices for a sentence
                 captions = []
                 cap_lens = []
@@ -121,7 +121,7 @@ def gen_example(wordtoix, algo):
     # Get data loader
     imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
     image_transform = transforms.Compose([
-        transforms.Scale(int(imsize * 76 / 64)),
+        transforms.Resize(int(imsize * 76 / 64)),
         transforms.RandomCrop(imsize),
         transforms.RandomHorizontalFlip()])
     dataset = TextDataset(cfg.DATA_DIR, split_dir,
diff --git a/code/miscc/config.py b/code/miscc/config.py
index 797319ba..e10990f0 100644
--- a/code/miscc/config.py
+++ b/code/miscc/config.py
@@ -70,9 +70,9 @@ def _merge_a_into_b(a, b):
     if type(a) is not edict:
         return
 
-    for k, v in a.iteritems():
+    for k, v in a.items():
         # a must specify keys that are in b
-        if not b.has_key(k):
+        if k not in b:
             raise KeyError('{} is not a valid config key'.format(k))
 
         # the types must match, too
diff --git a/code/miscc/losses.py b/code/miscc/losses.py
index b15612bf..90efd5c2 100644
--- a/code/miscc/losses.py
+++ b/code/miscc/losses.py
@@ -181,7 +181,7 @@ def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
             g_loss = cond_errG
         errG_total += g_loss
         # err_img = errG_total.data[0]
-        logs += 'g_loss%d: %.2f ' % (i, g_loss.data[0])
+        logs += 'g_loss%d: %.2f ' % (i, g_loss.item())
 
         # Ranking loss
         if i == (numDs - 1):
@@ -202,7 +202,7 @@ def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
             # err_sent = err_sent + s_loss.data[0]
 
             errG_total += w_loss + s_loss
-            logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.data[0], s_loss.data[0])
+            logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.item(), s_loss.item())
 
     return errG_total, logs
diff --git a/code/miscc/utils.py b/code/miscc/utils.py
index f131a365..04d16ebe 100644
--- a/code/miscc/utils.py
+++ b/code/miscc/utils.py
@@ -75,7 +75,8 @@ def build_super_images(real_imgs, captions, ixtoword,
 
     real_imgs = \
-        nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
+        nn.functional.interpolate(real_imgs, size=(vis_size, vis_size),
+                                  mode='bilinear', align_corners=False)
     # [-1, 1] --> [0, 1]
     real_imgs.add_(1).div_(2).mul_(255)
     real_imgs = real_imgs.data.numpy()
@@ -86,7 +87,8 @@ def build_super_images(real_imgs, captions, ixtoword,
     post_pad = np.zeros([pad_sze[1], pad_sze[2], 3])
     if lr_imgs is not None:
         lr_imgs = \
-            nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(lr_imgs)
+            nn.functional.interpolate(lr_imgs, size=(vis_size, vis_size),
+                                      mode='bilinear', align_corners=False)
         # [-1, 1] --> [0, 1]
         lr_imgs.add_(1).div_(2).mul_(255)
         lr_imgs = lr_imgs.data.numpy()
@@ -129,7 +131,8 @@ def build_super_images(real_imgs, captions, ixtoword,
             if (vis_size // att_sze) > 1:
                 one_map = \
                     skimage.transform.pyramid_expand(one_map, sigma=20,
-                                                     upscale=vis_size // att_sze)
+                                                     upscale=vis_size // att_sze,
+                                                     multichannel=True)
             row_beforeNorm.append(one_map)
             minV = one_map.min()
             maxV = one_map.max()
@@ -185,7 +188,8 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
                        dtype=np.uint8)
 
     real_imgs = \
-        nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
+        nn.functional.interpolate(real_imgs, size=(vis_size, vis_size),
+                                  mode='bilinear', align_corners=False)
     # [-1, 1] --> [0, 1]
     real_imgs.add_(1).div_(2).mul_(255)
     real_imgs = real_imgs.data.numpy()
@@ -228,7 +232,8 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
             if (vis_size // att_sze) > 1:
                 one_map = \
                     skimage.transform.pyramid_expand(one_map, sigma=20,
-                                                     upscale=vis_size // att_sze)
+                                                     upscale=vis_size // att_sze,
+                                                     multichannel=True)
             minV = one_map.min()
             maxV = one_map.max()
             one_map = (one_map - minV) / (maxV - minV)
@@ -286,12 +291,12 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
 def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
-        nn.init.orthogonal(m.weight.data, 1.0)
+        nn.init.orthogonal_(m.weight.data, 1.0)
     elif classname.find('BatchNorm') != -1:
         m.weight.data.normal_(1.0, 0.02)
         m.bias.data.fill_(0)
     elif classname.find('Linear') != -1:
-        nn.init.orthogonal(m.weight.data, 1.0)
+        nn.init.orthogonal_(m.weight.data, 1.0)
         if m.bias is not None:
             m.bias.data.fill_(0.0)
diff --git a/code/model.py b/code/model.py
index 09bde34e..9bc514b9 100644
--- a/code/model.py
+++ b/code/model.py
@@ -20,7 +20,19 @@ def forward(self, x):
         nc = x.size(1)
         assert nc % 2 == 0, 'channels dont divide 2!'
         nc = int(nc/2)
-        return x[:, :nc] * F.sigmoid(x[:, nc:])
+        return x[:, :nc] * torch.sigmoid(x[:, nc:])
+
+class Interpolate(nn.Module):
+    def __init__(self, scale_factor, mode, size=None):
+        super(Interpolate, self).__init__()
+        self.interp = nn.functional.interpolate
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.size = size
+
+    def forward(self, x):
+        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, size=self.size)
+        return x
 
 
 def conv1x1(in_planes, out_planes, bias=False):
@@ -38,7 +50,7 @@ def conv3x3(in_planes, out_planes):
 # Upsale the spatial size by a factor of 2
 def upBlock(in_planes, out_planes):
     block = nn.Sequential(
-        nn.Upsample(scale_factor=2, mode='nearest'),
+        Interpolate(scale_factor=2, mode='nearest'),
         conv3x3(in_planes, out_planes * 2),
         nn.BatchNorm2d(out_planes * 2),
         GLU())
@@ -207,7 +219,7 @@ def init_trainable_weights(self):
     def forward(self, x):
         features = None
         # --> fixed-size input: batch x 3 x 299 x 299
-        x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
+        x = nn.functional.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
         # 299 x 299 x 3
         x = self.Conv2d_1a_3x3(x)
         # 149 x 149 x 32
diff --git a/code/pretrain_DAMSM.py b/code/pretrain_DAMSM.py
index 5f8b0ff9..1bce8167 100644
--- a/code/pretrain_DAMSM.py
+++ b/code/pretrain_DAMSM.py
@@ -235,7 +235,7 @@ def build_models():
     imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1))
     batch_size = cfg.TRAIN.BATCH_SIZE
     image_transform = transforms.Compose([
-        transforms.Scale(int(imsize * 76 / 64)),
+        transforms.Resize(int(imsize * 76 / 64)),
         transforms.RandomCrop(imsize),
         transforms.RandomHorizontalFlip()])
     dataset = TextDataset(cfg.DATA_DIR, 'train',
diff --git a/code/trainer.py b/code/trainer.py
index a6d4180f..1bbd2367 100644
--- a/code/trainer.py
+++ b/code/trainer.py
@@ -274,7 +274,7 @@ def train(self):
                     errD.backward()
                     optimizersD[i].step()
                     errD_total += errD
-                    D_logs += 'errD%d: %.2f ' % (i, errD.data[0])
+                    D_logs += 'errD%d: %.2f ' % (i, errD.item())
 
                 #######################################################
                 # (4) Update G network: maximize log(D(G(z)))
@@ -291,7 +291,7 @@ def train(self):
                     words_embs, sent_emb, match_labels, cap_lens, class_ids)
                 kl_loss = KL_loss(mu, logvar)
                 errG_total += kl_loss
-                G_logs += 'kl_loss: %.2f ' % kl_loss.data[0]
+                G_logs += 'kl_loss: %.2f ' % kl_loss.item()
                 # backward and update parameters
                 errG_total.backward()
                 optimizerG.step()
@@ -318,7 +318,7 @@ def train(self):
             print('''[%d/%d][%d]
                   Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                   % (epoch, self.max_epoch, self.num_batches,
-                     errD_total.data[0], errG_total.data[0],
+                     errD_total.item(), errG_total.item(),
                      end_t - start_t))
 
             if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
@@ -370,8 +370,10 @@ def sampling(self, split_dir):
             batch_size = self.batch_size
             nz = cfg.GAN.Z_DIM
-            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
-            noise = noise.cuda()
+
+            with torch.no_grad():
+                noise = Variable(torch.FloatTensor(batch_size, nz))
+                noise = noise.cuda()
 
             model_dir = cfg.TRAIN.NET_G
             state_dict = \
@@ -463,14 +465,18 @@ def gen_example(self, data_dic):
             batch_size = captions.shape[0]
             nz = cfg.GAN.Z_DIM
-            captions = Variable(torch.from_numpy(captions), volatile=True)
-            cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)
-            captions = captions.cuda()
-            cap_lens = cap_lens.cuda()
+            with torch.no_grad():
+                captions = Variable(torch.from_numpy(captions))
+                cap_lens = Variable(torch.from_numpy(cap_lens))
+
+                captions = captions.cuda()
+                cap_lens = cap_lens.cuda()
+
             for i in range(1):  # 16
-                noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
-                noise = noise.cuda()
+                with torch.no_grad():
+                    noise = Variable(torch.FloatTensor(batch_size, nz))
+                    noise = noise.cuda()
                 #######################################################
                 # (1) Extract text embeddings
                 ######################################################
diff --git a/eval/eval.py b/eval/eval.py
index 48005f73..dc42f39c 100644
--- a/eval/eval.py
+++ b/eval/eval.py
@@ -54,16 +54,17 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi
     batch_size = captions.shape[0]
     nz = cfg.GAN.Z_DIM
-    captions = Variable(torch.from_numpy(captions), volatile=True)
-    cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)
-    noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
+    with torch.no_grad():
+        captions = Variable(torch.from_numpy(captions))
+        cap_lens = Variable(torch.from_numpy(cap_lens))
+        noise = Variable(torch.FloatTensor(batch_size, nz))
 
     if cfg.CUDA:
         captions = captions.cuda()
         cap_lens = cap_lens.cuda()
         noise = noise.cuda()
-    
+
 
     #######################################################
     # (1) Extract text embeddings
@@ -71,7 +72,7 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi
     hidden = text_encoder.init_hidden(batch_size)
     words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
     mask = (captions == 0)
-    
+
 
     #######################################################
     # (2) Generate fake images
@@ -131,7 +132,7 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi
             im = fake_imgs[k + 1].detach().cpu()
         else:
             im = fake_imgs[0].detach().cpu()
-        
+
         attn_maps = attention_maps[k]
         att_sze = attn_maps.size(2)
 
@@ -152,7 +153,7 @@ def generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service, copi
             urls.append(full_path % blob_name)
         if copies == 2:
             break
-    
+
     #print(len(urls), urls)
 
     return urls
@@ -223,7 +224,7 @@ def eval(caption):
 if __name__ == "__main__":
     caption = "the bird has a yellow crown and a black eyering that is round"
-    
+
     # load configuration
     #cfg_from_file('eval_bird.yml')
     # load word dictionaries
@@ -232,9 +233,9 @@ def eval(caption):
     text_encoder, netG = models(len(wordtoix))
     # load blob service
     blob_service = BlockBlobService(account_name='attgan', account_key='[REDACTED]')
-    
+
     t0 = time.time()
     urls = generate(caption, wordtoix, ixtoword, text_encoder, netG, blob_service)
     t1 = time.time()
     print(t1-t0)
-    print(urls)
\ No newline at end of file
+    print(urls)
diff --git a/eval/miscc/utils.py b/eval/miscc/utils.py
index 13fc4739..1993b4c0 100644
--- a/eval/miscc/utils.py
+++ b/eval/miscc/utils.py
@@ -58,7 +58,7 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
                        dtype=np.uint8)
 
     real_imgs = \
-        nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
+        nn.functional.interpolate(real_imgs, size=(vis_size, vis_size), mode='bilinear', align_corners=False)
     # [-1, 1] --> [0, 1]
     real_imgs.add_(1).div_(2).mul_(255)
     real_imgs = real_imgs.data.numpy()
@@ -159,12 +159,12 @@ def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
 def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
-        nn.init.orthogonal(m.weight.data, 1.0)
+        nn.init.orthogonal_(m.weight.data, 1.0)
     elif classname.find('BatchNorm') != -1:
         m.weight.data.normal_(1.0, 0.02)
         m.bias.data.fill_(0)
     elif classname.find('Linear') != -1:
-        nn.init.orthogonal(m.weight.data, 1.0)
+        nn.init.orthogonal_(m.weight.data, 1.0)
         if m.bias is not None:
             m.bias.data.fill_(0.0)
diff --git a/eval/model.py b/eval/model.py
index 6d37ab3d..183769c1 100644
--- a/eval/model.py
+++ b/eval/model.py
@@ -26,7 +26,7 @@ def __init__(self, ntoken, ninput=300, drop_prob=0.5,
         self.drop_prob = drop_prob  # probability of an element to be zeroed
         self.nlayers = nlayers  # Number of recurrent layers
         self.bidirectional = bidirectional
-        
+
         if bidirectional:
             self.num_directions = 2
         else:
@@ -113,7 +113,7 @@ def __init__(self):
         nef = cfg.TEXT.EMBEDDING_DIM
         ncf = cfg.GAN.CONDITION_DIM
-        
+
         self.ca_net = CA_NET()
 
         if cfg.TREE.BRANCH_NUM > 0:
@@ -170,7 +170,7 @@ def __init__(self):
         self.t_dim = cfg.TEXT.EMBEDDING_DIM
         self.c_dim = cfg.GAN.CONDITION_DIM
-        
+
         self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True)
         self.relu = GLU()
 
@@ -204,7 +204,19 @@ def forward(self, x):
         nc = x.size(1)
         assert nc % 2 == 0, 'channels dont divide 2!'
         nc = int(nc/2)
-        return x[:, :nc] * F.sigmoid(x[:, nc:])
+        return x[:, :nc] * torch.sigmoid(x[:, nc:])
+
+class Interpolate(nn.Module):
+    def __init__(self, scale_factor, mode, size=None):
+        super(Interpolate, self).__init__()
+        self.interp = nn.functional.interpolate
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.size = size
+
+    def forward(self, x):
+        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, size=self.size)
+        return x
 
 
 def conv1x1(in_planes, out_planes, bias=False):
@@ -222,7 +234,7 @@ def conv3x3(in_planes, out_planes):
 # Upsale the spatial size by a factor of 2
 def upBlock(in_planes, out_planes):
     block = nn.Sequential(
-        nn.Upsample(scale_factor=2, mode='nearest'),
+        Interpolate(scale_factor=2, mode='nearest'),
         conv3x3(in_planes, out_planes * 2),
         nn.BatchNorm2d(out_planes * 2),
         GLU())
@@ -304,7 +316,7 @@ def init_trainable_weights(self):
     def forward(self, x):
         features = None
         # --> fixed-size input: batch x 3 x 299 x 299
-        x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
+        x = nn.functional.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
         # 299 x 299 x 3
         x = self.Conv2d_1a_3x3(x)
         # 149 x 149 x 32
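The `nn.Softmax(dim=1)` changes in `code/GlobalAttention.py` exist because PyTorch 0.4+ warns (and later versions error) when `Softmax` is constructed without an explicit `dim`; the old implicit choice depended on the input's dimensionality. In `func_attention` the attention matrix is flattened to shape `(batch*sourceL, queryL)` (and later `(batch*queryL, sourceL)`), so `dim=1` normalizes each row. A minimal sketch with made-up shapes:

```python
import torch
import torch.nn as nn

# Toy stand-in for the flattened attention matrix in func_attention:
# rows index (batch * sourceL), columns index queryL.
attn = torch.randn(6, 4)
attn = nn.Softmax(dim=1)(attn)  # each row now sums to 1 across queryL
print(attn.sum(dim=1))          # tensor([1., 1., 1., 1., 1., 1.])
```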
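The `Interpolate` wrapper added to `code/model.py` (and mirrored above in `eval/model.py`) is needed because `nn.functional.interpolate`, the replacement for the deprecated `nn.Upsample` forward path, is a plain function whose first argument is the input tensor; it cannot be placed inside `nn.Sequential` directly, and `nn.functional.interpolate(scale_factor=2, mode='nearest')` with no input raises a `TypeError`. A self-contained sketch of the pattern (the 16- and 32-channel conv sizes are arbitrary, chosen only for the demo):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Interpolate(nn.Module):
    """Module wrapper so F.interpolate can sit inside nn.Sequential."""
    def __init__(self, scale_factor, mode, size=None):
        super(Interpolate, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.size = size

    def forward(self, x):
        return F.interpolate(x, size=self.size,
                             scale_factor=self.scale_factor, mode=self.mode)

# Doubles the spatial size, then convolves -- the same shape as upBlock.
up = nn.Sequential(
    Interpolate(scale_factor=2, mode='nearest'),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
)
print(up(torch.randn(1, 16, 8, 8)).shape)  # torch.Size([1, 32, 16, 16])
```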
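One caveat on the `volatile=True` to `torch.no_grad()` changes in `code/trainer.py` and `eval/eval.py`: `volatile` propagated through every operation that touched the tensor, whereas `torch.no_grad()` only disables gradient tracking for operations executed inside the block. Wrapping just the tensor creation therefore does not stop autograd from recording the later generator forward pass; for inference-only sampling, the forward call itself belongs inside the context. A sketch of that fuller pattern, assuming the `netG(noise, sent_emb, words_embs, mask)` call signature used elsewhere in `trainer.py` (illustrative, not a drop-in patch):

```python
import torch

def sample(netG, sent_emb, words_embs, mask, batch_size, nz):
    # Plain tensors suffice; Variable is a deprecated no-op in PyTorch >= 0.4.
    noise = torch.randn(batch_size, nz, device=sent_emb.device)
    with torch.no_grad():  # the forward pass is what must skip autograd
        fake_imgs, attention_maps, mu, logvar = \
            netG(noise, sent_emb, words_embs, mask)
    return fake_imgs
```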
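The `pickle.load(f, encoding="bytes")` change in `code/datasets.py` is there because `class_info.pickle` was written by Python 2, and Python 3's default ASCII decoding of Python 2 byte strings can raise `UnicodeDecodeError`. With `encoding='bytes'`, any pickled strings stay as `bytes`; `encoding='latin1'` instead decodes them to `str` and is the usual choice when a pickle contains NumPy arrays. Which is appropriate depends on how the loaded object is consumed downstream. A minimal sketch mirroring `load_class_id`:

```python
import pickle

# Load a Python-2-era pickle under Python 3. Pickled byte strings stay as
# bytes here; pass encoding='latin1' to decode them to str instead.
with open('class_info.pickle', 'rb') as f:
    class_id = pickle.load(f, encoding='bytes')
```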
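Finally, the `nn.init.orthogonal` to `nn.init.orthogonal_` renames follow PyTorch 0.4's convention that in-place operations carry a trailing underscore; the non-underscore init names were deprecated aliases. The semantics are unchanged, as this quick check illustrates:

```python
import torch
import torch.nn as nn

# Trailing underscore = in-place: orthogonal_ overwrites the tensor it is given.
w = torch.empty(64, 32)
nn.init.orthogonal_(w, gain=1.0)
# Columns are orthonormal, so w^T w is (close to) the identity.
print(torch.allclose(w.t().mm(w), torch.eye(32), atol=1e-5))  # True
```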