From 17b7c02263a19855f786e90ae6d9714988a37626 Mon Sep 17 00:00:00 2001 From: pegahk Date: Mon, 5 Aug 2019 12:44:45 -0400 Subject: [PATCH] implement highway norm --- ccai.yaml | 62 ++++++++++++ data_seg.py | 144 ++++++++++++++++++++++++++ networks.py | 76 +++++++++----- train.py | 55 +++++++--- trainer.py | 286 ++++++++++++---------------------------------------- utils.py | 12 +-- 6 files changed, 368 insertions(+), 267 deletions(-) create mode 100755 ccai.yaml create mode 100644 data_seg.py diff --git a/ccai.yaml b/ccai.yaml new file mode 100755 index 000000000..0675904b5 --- /dev/null +++ b/ccai.yaml @@ -0,0 +1,62 @@ +# Copyright (C) 2018 NVIDIA Corporation. All rights reserved. +# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). + +# logger options +image_save_iter: 5000 # How often do you want to save output images during training +image_display_iter: 200 # How often do you want to display output images during training +display_size: 16 # How many images do you want to display each time +snapshot_save_iter: 10000 # How often do you want to save trained models +log_iter: 1 # How often do you want to log the training stats + +# optimization options +max_iter: 400000 # maximum number of training iterations +batch_size: 1 # batch size +weight_decay: 0.0001 # weight decay +beta1: 0.5 # Adam parameter +beta2: 0.999 # Adam parameter +init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] +lr: 0.0001 # initial learning rate +lr_policy: step # learning rate scheduler +step_size: 100000 # how often to decay learning rate +gamma: 0.5 # how much to decay learning rate +gan_w: 1 # weight of adversarial loss +recon_x_w: 5 # weight of image reconstruction loss +recon_s_w: 1 # weight of style reconstruction loss +recon_c_w: 1 # weight of content reconstruction loss +recon_x_cyc_w: 5 # weight of explicit style augmented cycle consistency loss +vgg_w: 0 # weight of domain-invariant perceptual loss + +# model options +gen: + dim: 64 # number of filters in the bottommost layer + mlp_dim: 256 # number of filters in MLP + style_dim: 12 # length of style code + activ: relu # activation function [relu/lrelu/prelu/selu/tanh] + n_downsample: 2 # number of downsampling layers in content encoder + n_res: 4 # number of residual blocks in content encoder/decoder + pad_type: reflect # padding type [zero/reflect] +dis: + dim: 64 # number of filters in the bottommost layer + norm: none # normalization layer [none/bn/in/ln] + activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh] + n_layer: 4 # number of layers in D + gan_type: lsgan # GAN loss [lsgan/nsgan] + num_scales: 3 # number of scales + pad_type: reflect # padding type [zero/reflect] + +# data options +input_dim_a: 3 # number of image channels [1/3] +input_dim_b: 3 # number of image channels [1/3] +num_workers: 0 # number of data loading threads +new_size: 256 # first resize the shortest image side to this size +crop_image_height: 256 # random crop image of this height +crop_image_width: 256 # random crop image of this width + +data_folder_train_a: . +data_list_train_a: ./data/mapillary.txt +data_folder_test_a: . +data_list_test_a: ./data/mapillary.txt +data_folder_train_b: . +data_list_train_b: ./data/sorted_crf.txt +data_folder_test_b: . +data_list_test_b: ./data/sorted_crf.txt \ No newline at end of file diff --git a/data_seg.py b/data_seg.py new file mode 100644 index 000000000..860129ec8 --- /dev/null +++ b/data_seg.py @@ -0,0 +1,144 @@ +""" +Copyright (C) 2018 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +""" +import torch.utils.data as data +import os.path +import numpy as np +from torchvision import transforms + + +def default_loader(path, convert=True): + if convert: + return Image.open(path).convert('RGB') + else: + return Image.open(path) + + +def default_flist_reader(flist): + """ + flist format: impath label\nimpath label\n ...(same to caffe's filelist) + """ + imlist = [] + seglist = [] + with open(flist, 'r') as rf: + for line in rf.readlines(): + impath, segpath = line.strip().split(' ') + imlist.append(impath) + seglist.append(segpath) + if not os.path.exists(segpath): + print('not found', segpath) + + return imlist, seglist + + +class ImageFilelist(data.Dataset): + def __init__(self, root, flist, transform=None, + flist_reader=default_flist_reader, loader=default_loader): + self.root = root + self.imlist, self.seglist = flist_reader(flist) + self.transform = transform + self.loader = loader + + def __getitem__(self, index): + impath = self.imlist[index] + segpath = self.seglist[index] + img = self.loader(os.path.join(self.root, impath)) + seg = self.loader(os.path.join(self.root, segpath), False) + transform1 = transforms.Compose(self.transform) + transform2 = transforms.Compose(self.transform + [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + img = transform2(img) + seg = transform1(seg) * 255 + #print('here', impath) + return img, seg, impath + + + def __len__(self): + return len(self.imlist) + + +class ImageLabelFilelist(data.Dataset): + def __init__(self, root, flist, transform=None, + flist_reader=default_flist_reader, loader=default_loader): + self.root = root + self.imlist = flist_reader(os.path.join(self.root, flist)) + self.transform = transform + self.loader = loader + self.classes = sorted(list(set([path.split('/')[0] for path in self.imlist]))) + self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))} + self.imgs = [(impath, self.class_to_idx[impath.split('/')[0]]) for impath in self.imlist] + + def __getitem__(self, index): + impath, label = self.imgs[index] + img = self.loader(os.path.join(self.root, impath)) + if self.transform is not None: + img = self.transform(img) + return img, label + + def __len__(self): + return len(self.imgs) + +############################################################################### +# Code from +# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py +# Modified the original code so that it also loads images from the current +# directory as well as the subdirectories +############################################################################### + +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + + return images + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = sorted(make_dataset(root)) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/networks.py b/networks.py index 099877f95..1caaffd67 100644 --- a/networks.py +++ b/networks.py @@ -66,6 +66,7 @@ def calc_dis_loss(self, input_fake, input_real): F.binary_cross_entropy(F.sigmoid(out1), all1)) else: assert 0, "Unsupported GAN type: {}".format(self.gan_type) + print('dis loss', loss.data) return loss def calc_gen_loss(self, input_fake): @@ -96,36 +97,43 @@ def __init__(self, input_dim, params): n_res = params['n_res'] activ = params['activ'] pad_type = params['pad_type'] - mlp_dim = params['mlp_dim'] + self.mlp_dim = params['mlp_dim'] + self.style_dim = style_dim # style encoder self.enc_style = StyleEncoder(4, input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type) # content encoder self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type) - self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, res_norm='adain', activ=activ, pad_type=pad_type) + self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, + self.style_dim, self.mlp_dim, res_norm='adain', activ=activ, pad_type=pad_type) # MLP to generate AdaIN parameters - self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ) + self.mlp1 = MLP(style_dim, self.get_num_adain_params(self.dec), self.mlp_dim, 3, norm='none', activ=activ) + self.mlp2 = MLP(style_dim, self.get_num_adain_params(self.dec), self.mlp_dim, 3, norm='none', activ=activ) - def forward(self, images): + def forward(self, images, masks): # reconstruct an image - content, style_fake = self.encode(images) - images_recon = self.decode(content, style_fake) + content, style_fg, style_bg = self.encode(images, masks) + images_recon = self.decode(content, style_fg, style_bg, masks) return images_recon - def encode(self, images): + def encode(self, images, masks): # encode an image to its content and style codes - style_fake = self.enc_style(images) + style_fg = self.enc_style(masks * images) + style_bg = self.enc_style((1-masks) * images, False) content = self.enc_content(images) - return content, style_fake + return content, style_fg, style_bg - def decode(self, content, style): + def decode(self, content, style_fg, style_bg, mask): # decode content and style codes to an image - adain_params = self.mlp(style) - self.assign_adain_params(adain_params, self.dec) - images = self.dec(content) - return images + adain_params1 = self.mlp1(style_fg) + adain_params2 = self.mlp2(style_bg) + self.assign_adain_params(adain_params1, self.dec) + images1 = self.dec(content, style_fg, style_bg, mask) + self.assign_adain_params(adain_params2, self.dec) + images2 = self.dec(content, style_bg, style_bg, mask) + return (mask * images1) + (1-mask) * images2 def assign_adain_params(self, adain_params, model): # assign the adain_params to the AdaIN layers in model @@ -148,7 +156,7 @@ def get_num_adain_params(self, model): class VAEGen(nn.Module): - # VAE architecture + # VAE architectureself def __init__(self, input_dim, params): super(VAEGen, self).__init__() dim = params['dim'] @@ -188,20 +196,30 @@ def decode(self, hiddens): class StyleEncoder(nn.Module): def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type): super(StyleEncoder, self).__init__() - self.model = [] - self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] + self.shared = [] + self.spec_layers = [] + self.shared += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] for i in range(2): - self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + self.shared += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] dim *= 2 for i in range(n_downsample - 2): - self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] - self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling - self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)] - self.model = nn.Sequential(*self.model) + self.spec_layers += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] + self.spec_layers = [nn.AdaptiveAvgPool2d(1)] # global average pooling + self.spec_layers += [nn.Conv2d(dim, style_dim, 1, 1, 0)] + + self.shared = nn.Sequential(*self.shared) + self.foreground = nn.Sequential(*self.spec_layers) + self.background = nn.Sequential(*self.spec_layers) + self.output_dim = dim - def forward(self, x): - return self.model(x) + def forward(self, x, isF=True): + x = self.shared(x) + if isF: + return self.foreground(x) + else: + return self.background(x) + class ContentEncoder(nn.Module): def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): @@ -220,13 +238,14 @@ def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): def forward(self, x): return self.model(x) + class Decoder(nn.Module): - def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'): + def __init__(self, n_upsample, n_res, dim, output_dim, style_dim, mlp_dim, res_norm='adain', activ='relu', pad_type='zero'): super(Decoder, self).__init__() self.model = [] # AdaIN residual blocks - self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)] + self.model += [ResBlocks(n_res, dim, 'adain', activ, pad_type=pad_type)] # upsampling blocks for i in range(n_upsample): self.model += [nn.Upsample(scale_factor=2), @@ -236,8 +255,9 @@ def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ=' self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)] self.model = nn.Sequential(*self.model) - def forward(self, x): - return self.model(x) + def forward(self, c_A, s_fB, s_bA, m_A): + return self.model(c_A) + ################################################################################## # Sequential Models diff --git a/train.py b/train.py index 0f24a2373..3bef4c678 100644 --- a/train.py +++ b/train.py @@ -5,9 +5,10 @@ from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer import argparse from torch.autograd import Variable -from trainer import MUNIT_Trainer, UNIT_Trainer +from trainer import MUNIT_Trainer import torch.backends.cudnn as cudnn import torch +import datetime try: from itertools import izip as zip except ImportError: # will be 3.x series @@ -16,6 +17,7 @@ import sys import tensorboardX import shutil +import pdb parser = argparse.ArgumentParser() parser.add_argument('--config', type=str, default='configs/edges2handbags_folder.yaml', help='Path to the config file.') @@ -39,31 +41,49 @@ trainer = UNIT_Trainer(config) else: sys.exit("Only support MUNIT|UNIT") +r = 50 trainer.cuda() train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config) -train_display_images_a = torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda() -train_display_images_b = torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda() -test_display_images_a = torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda() -test_display_images_b = torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda() +train_display_images_a = torch.stack([train_loader_a.dataset[i+r][0] for i in range(display_size)]).cuda() +train_display_images_b = torch.stack([train_loader_b.dataset[i+r][0] for i in range(display_size)]).cuda() +test_display_images_a = torch.stack([test_loader_a.dataset[i+r][0] for i in range(display_size)]).cuda() +test_display_images_b = torch.stack([test_loader_b.dataset[i+r][0] for i in range(display_size)]).cuda() + +train_display_segs_a = torch.stack([train_loader_a.dataset[i+r][1] for i in range(display_size)]).cuda() +train_display_segs_b = torch.stack([train_loader_b.dataset[i+r][1] for i in range(display_size)]).cuda() +test_display_segs_a = torch.stack([test_loader_a.dataset[i+r][1] for i in range(display_size)]).cuda() +test_display_segs_b = torch.stack([test_loader_b.dataset[i+r][1] for i in range(display_size)]).cuda() # Setup logger and output folders model_name = os.path.splitext(os.path.basename(opts.config))[0] train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name)) -output_directory = os.path.join(opts.output_path + "/outputs", model_name) +dtstr = datetime.datetime.now().strftime("%y%m%d_%H%M%S") +output_directory = os.path.join(opts.output_path + "/outputs", model_name, dtstr) checkpoint_directory, image_directory = prepare_sub_folder(output_directory) shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder # Start training iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0 while True: - for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)): + for it, (imageseg_a, imagesseg_b) in enumerate(zip(train_loader_a, train_loader_b)): + images_a = imageseg_a[0] + segs_a = imageseg_a[1] + path_a = imageseg_a[2] + + images_b = imagesseg_b[0] + segs_b = imagesseg_b[1] + path_b = imagesseg_b[2] + trainer.update_learning_rate() images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach() + segs_a, segs_b = segs_a.cuda().detach(), segs_b.cuda().detach() with Timer("Elapsed time in update: %f"): # Main training code - trainer.dis_update(images_a, images_b, config) - trainer.gen_update(images_a, images_b, config) + trainer.dis_update(images_a, segs_a, images_b, segs_b, config) + if 0 == iterations % 3: + # trainer.gen_update(images_a, segs_a, images_b, segs_b, config, True) + trainer.gen_update(images_a, segs_a, images_b, segs_b, config) torch.cuda.synchronize() # Dump training stats in log file @@ -74,16 +94,25 @@ # Write images if (iterations + 1) % config['image_save_iter'] == 0: with torch.no_grad(): - test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b) - train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b) + test_image_outputs = trainer.sample(test_display_images_a, + test_display_segs_a, + test_display_images_b, + test_display_segs_b) + train_image_outputs = trainer.sample(train_display_images_a, + train_display_segs_a, + train_display_images_b, + train_display_segs_b) write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1)) write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1)) # HTML write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images') - if (iterations + 1) % config['image_display_iter'] == 0: + if ((iterations + 1) % config['image_display_iter'] == 0) or (iterations == 100): with torch.no_grad(): - image_outputs = trainer.sample(train_display_images_a, train_display_images_b) + image_outputs = trainer.sample(train_display_images_a, + train_display_segs_a, + train_display_images_b, + train_display_segs_b) write_2images(image_outputs, display_size, image_directory, 'train_current') # Save network weights diff --git a/trainer.py b/trainer.py index 2694a5ddc..e8e0baac8 100644 --- a/trainer.py +++ b/trainer.py @@ -49,7 +49,6 @@ def __init__(self, hyperparameters): self.vgg.eval() for param in self.vgg.parameters(): param.requires_grad = False - def recon_criterion(self, input, target): return torch.mean(torch.abs(input - target)) @@ -64,33 +63,42 @@ def forward(self, x_a, x_b): self.train() return x_ab, x_ba - def gen_update(self, x_a, x_b, hyperparameters): + def gen_update(self, x_a, m_A, x_b, m_B, hyperparameters, st=False): self.gen_opt.zero_grad() - s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) - s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) # encode - c_a, s_a_prime = self.gen_a.encode(x_a) - c_b, s_b_prime = self.gen_b.encode(x_b) + c_a, s_fA, s_bA = self.gen_a.encode(x_a, m_A) + c_b, s_fB, s_bB = self.gen_b.encode(x_b, m_B) + # decode (within domain) - x_a_recon = self.gen_a.decode(c_a, s_a_prime) - x_b_recon = self.gen_b.decode(c_b, s_b_prime) + x_aa = self.gen_a.decode(c_a, s_fA, s_bA, m_A) + x_bb = self.gen_b.decode(c_b, s_fB, s_bB, m_B) + # decode (cross domain) - x_ba = self.gen_a.decode(c_b, s_a) - x_ab = self.gen_b.decode(c_a, s_b) + x_ba = self.gen_a.decode(c_b, s_fA, s_bB, m_B) + x_ab = self.gen_b.decode(c_a, s_fB, s_bA, m_A) # encode again - c_b_recon, s_a_recon = self.gen_a.encode(x_ba) - c_a_recon, s_b_recon = self.gen_b.encode(x_ab) + c_b_recon, s_fA_recon, s_bA_recon = self.gen_a.encode(x_ba, m_B) + c_a_recon, s_fB_recon, s_bB_recon = self.gen_b.encode(x_ab, m_A) # decode again (if needed) - x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None - x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None + x_aba = self.gen_a.decode(c_a_recon, s_fA, s_bA, m_A) if hyperparameters['recon_x_cyc_w'] > 0 else None + x_bab = self.gen_b.decode(c_b_recon, s_fB, s_bB, m_B) if hyperparameters['recon_x_cyc_w'] > 0 else None # reconstruction loss - self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) - self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) - self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) - self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) + self.loss_gen_recon_x_a = self.recon_criterion(x_aa, x_a) + self.loss_gen_recon_x_b = self.recon_criterion(x_bb, x_b) + + self.loss_gen_recon_x_aB = self.recon_criterion((1-m_A)*x_ab, (1-m_A)*x_a) + self.loss_gen_recon_x_bB = self.recon_criterion((1-m_B)*x_ba, (1-m_B)*x_b) + + self.loss_gen_recon_s_bA = self.recon_criterion(s_bA_recon, s_bB) + self.loss_gen_recon_s_bB = self.recon_criterion(s_bB_recon, s_bA) + self.loss_gen_recon_s_fA = self.recon_criterion(s_fA_recon, s_fA) + self.loss_gen_recon_s_fB = self.recon_criterion(s_fB_recon, s_fB) + self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a) self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b) + + # translation self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 # GAN loss @@ -99,19 +107,27 @@ def gen_update(self, x_a, x_b, hyperparameters): # domain-invariant perceptual loss self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 + self.loss_disentangle = -self.recon_criterion(s_fA, s_fB) + #disentanglement loss + # total loss self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ hyperparameters['gan_w'] * self.loss_gen_adv_b + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ - hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_aB + \ + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_fA + \ + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_bA + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \ hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ - hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_bB + \ + 10* hyperparameters['recon_s_w'] * self.loss_gen_recon_s_fB + \ + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_bB + \ hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \ hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ - hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ - hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ - hyperparameters['vgg_w'] * self.loss_gen_vgg_b + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + #hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ + #hyperparameters['vgg_w'] * self.loss_gen_vgg_b + print('gen_loss', self.loss_gen_total) self.loss_gen_total.backward() self.gen_opt.step() @@ -122,7 +138,7 @@ def compute_vgg_loss(self, vgg, img, target): target_fea = vgg(target_vgg) return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) - def sample(self, x_a, x_b): + def sample(self, x_a, m_a, x_b, m_b): self.eval() s_a1 = Variable(self.s_a) s_b1 = Variable(self.s_b) @@ -130,212 +146,44 @@ def sample(self, x_a, x_b): s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [] for i in range(x_a.size(0)): - c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0)) - c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0)) - x_a_recon.append(self.gen_a.decode(c_a, s_a_fake)) - x_b_recon.append(self.gen_b.decode(c_b, s_b_fake)) - x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0))) - x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0))) - x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0))) - x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0))) - x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) - x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) - x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) - self.train() - return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2 - def dis_update(self, x_a, x_b, hyperparameters): - self.dis_opt.zero_grad() - s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()) - s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()) - # encode - c_a, _ = self.gen_a.encode(x_a) - c_b, _ = self.gen_b.encode(x_b) - # decode (cross domain) - x_ba = self.gen_a.decode(c_b, s_a) - x_ab = self.gen_b.decode(c_a, s_b) - # D loss - self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) - self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) - self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b - self.loss_dis_total.backward() - self.dis_opt.step() + m_A = m_a[i].unsqueeze(0) + m_B = m_b[i].unsqueeze(0) - def update_learning_rate(self): - if self.dis_scheduler is not None: - self.dis_scheduler.step() - if self.gen_scheduler is not None: - self.gen_scheduler.step() + c_a, s_fA, s_bA = self.gen_a.encode(x_a[i].unsqueeze(0), m_A) + c_b, s_fB, s_bB = self.gen_b.encode(x_b[i].unsqueeze(0), m_B) + x_a_recon.append(self.gen_a.decode(c_a, s_fA, s_bA, m_A)) + x_b_recon.append(self.gen_b.decode(c_b, s_fB, s_bB, m_B)) - def resume(self, checkpoint_dir, hyperparameters): - # Load generators - last_model_name = get_model_list(checkpoint_dir, "gen") - state_dict = torch.load(last_model_name) - self.gen_a.load_state_dict(state_dict['a']) - self.gen_b.load_state_dict(state_dict['b']) - iterations = int(last_model_name[-11:-3]) - # Load discriminators - last_model_name = get_model_list(checkpoint_dir, "dis") - state_dict = torch.load(last_model_name) - self.dis_a.load_state_dict(state_dict['a']) - self.dis_b.load_state_dict(state_dict['b']) - # Load optimizers - state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) - self.dis_opt.load_state_dict(state_dict['dis']) - self.gen_opt.load_state_dict(state_dict['gen']) - # Reinitilize schedulers - self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) - self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) - print('Resume from iteration %d' % iterations) - return iterations - - def save(self, snapshot_dir, iterations): - # Save generators, discriminators, and optimizers - gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) - dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) - opt_name = os.path.join(snapshot_dir, 'optimizer.pt') - torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name) - torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name) - torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name) + x_ba2.append(self.gen_a.decode(c_b, s_fA, s_bB, m_B)) + x_ab2.append(self.gen_b.decode(c_a, s_fB, s_bA, m_A)) + x_BA1 = self.gen_a.decode(c_b, s_a1[i].unsqueeze(0), s_a2[i].unsqueeze(0), m_B) + x_AB1 = self.gen_b.decode(c_a, s_b1[i].unsqueeze(0), s_b2[i].unsqueeze(0), m_A) + x_AB1 = (1 * (1-m_A) * x_a[i].unsqueeze(0) + (0 * (1-m_A) * x_AB1)) + m_A * x_AB1 + x_BA1 = (1 * (1-m_B) * x_b[i].unsqueeze(0) + (0 * (1-m_B) * x_BA1)) + m_B * x_BA1 -class UNIT_Trainer(nn.Module): - def __init__(self, hyperparameters): - super(UNIT_Trainer, self).__init__() - lr = hyperparameters['lr'] - # Initiate the networks - self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen']) # auto-encoder for domain a - self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen']) # auto-encoder for domain b - self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis']) # discriminator for domain a - self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis']) # discriminator for domain b - self.instancenorm = nn.InstanceNorm2d(512, affine=False) + if 0 == i% 2: + x_ba1.append(x_BA1) + x_ab1.append(x_AB1) + else: + x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0), s_a2[i].unsqueeze(0), m_B)) + x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0), s_b2[i].unsqueeze(0), m_A)) - # Setup the optimizers - beta1 = hyperparameters['beta1'] - beta2 = hyperparameters['beta2'] - dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters()) - gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters()) - self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], - lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) - self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], - lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) - self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) - self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) - - # Network weight initialization - self.apply(weights_init(hyperparameters['init'])) - self.dis_a.apply(weights_init('gaussian')) - self.dis_b.apply(weights_init('gaussian')) - - # Load VGG model if needed - if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: - self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') - self.vgg.eval() - for param in self.vgg.parameters(): - param.requires_grad = False - - def recon_criterion(self, input, target): - return torch.mean(torch.abs(input - target)) - - def forward(self, x_a, x_b): - self.eval() - h_a, _ = self.gen_a.encode(x_a) - h_b, _ = self.gen_b.encode(x_b) - x_ba = self.gen_a.decode(h_b) - x_ab = self.gen_b.decode(h_a) - self.train() - return x_ab, x_ba - - def __compute_kl(self, mu): - # def _compute_kl(self, mu, sd): - # mu_2 = torch.pow(mu, 2) - # sd_2 = torch.pow(sd, 2) - # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0) - # return encoding_loss - mu_2 = torch.pow(mu, 2) - encoding_loss = torch.mean(mu_2) - return encoding_loss - - def gen_update(self, x_a, x_b, hyperparameters): - self.gen_opt.zero_grad() - # encode - h_a, n_a = self.gen_a.encode(x_a) - h_b, n_b = self.gen_b.encode(x_b) - # decode (within domain) - x_a_recon = self.gen_a.decode(h_a + n_a) - x_b_recon = self.gen_b.decode(h_b + n_b) - # decode (cross domain) - x_ba = self.gen_a.decode(h_b + n_b) - x_ab = self.gen_b.decode(h_a + n_a) - # encode again - h_b_recon, n_b_recon = self.gen_a.encode(x_ba) - h_a_recon, n_a_recon = self.gen_b.encode(x_ab) - # decode again (if needed) - x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None - x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None - - # reconstruction loss - self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) - self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) - self.loss_gen_recon_kl_a = self.__compute_kl(h_a) - self.loss_gen_recon_kl_b = self.__compute_kl(h_b) - self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a) - self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b) - self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon) - self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon) - # GAN loss - self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba) - self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab) - # domain-invariant perceptual loss - self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 - self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 - # total loss - self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ - hyperparameters['gan_w'] * self.loss_gen_adv_b + \ - hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ - hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \ - hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ - hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \ - hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \ - hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \ - hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \ - hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \ - hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ - hyperparameters['vgg_w'] * self.loss_gen_vgg_b - self.loss_gen_total.backward() - self.gen_opt.step() - - def compute_vgg_loss(self, vgg, img, target): - img_vgg = vgg_preprocess(img) - target_vgg = vgg_preprocess(target) - img_fea = vgg(img_vgg) - target_fea = vgg(target_vgg) - return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) - - def sample(self, x_a, x_b): - self.eval() - x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], [] - for i in range(x_a.size(0)): - h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0)) - h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0)) - x_a_recon.append(self.gen_a.decode(h_a)) - x_b_recon.append(self.gen_b.decode(h_b)) - x_ba.append(self.gen_a.decode(h_b)) - x_ab.append(self.gen_b.decode(h_a)) x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) - x_ba = torch.cat(x_ba) - x_ab = torch.cat(x_ab) + x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2) + x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2) self.train() - return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba + return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2 - def dis_update(self, x_a, x_b, hyperparameters): + def dis_update(self, x_a, m_a, x_b, m_b, hyperparameters): self.dis_opt.zero_grad() # encode - h_a, n_a = self.gen_a.encode(x_a) - h_b, n_b = self.gen_b.encode(x_b) + c_a, s_fA, s_bA = self.gen_a.encode(x_a, m_a) + c_b, s_fB, s_bB = self.gen_b.encode(x_b, m_b) # decode (cross domain) - x_ba = self.gen_a.decode(h_b + n_b) - x_ab = self.gen_b.decode(h_a + n_a) + x_ba = self.gen_a.decode(c_b, s_fA, s_bB, m_b) + x_ab = self.gen_b.decode(c_a, s_fB, s_bA, m_a) # D loss self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a) self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b) @@ -378,4 +226,4 @@ def save(self, snapshot_dir, iterations): opt_name = os.path.join(snapshot_dir, 'optimizer.pt') torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name) torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name) - torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name) + torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name) \ No newline at end of file diff --git a/utils.py b/utils.py index 299bc48e2..ccaad6f24 100644 --- a/utils.py +++ b/utils.py @@ -8,7 +8,7 @@ from torch.autograd import Variable from torch.optim import lr_scheduler from torchvision import transforms -from data import ImageFilelist, ImageFolder +from data_seg import ImageFilelist, ImageFolder import torch import torch.nn as nn import os @@ -72,14 +72,12 @@ def get_all_data_loaders(conf): def get_data_loader_list(root, file_list, batch_size, train, new_size=None, height=256, width=256, num_workers=4, crop=True): - transform_list = [transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), - (0.5, 0.5, 0.5))] + transform_list = [transforms.ToTensor()] transform_list = [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list transform_list = [transforms.Resize(new_size)] + transform_list if new_size is not None else transform_list - transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list - transform = transforms.Compose(transform_list) - dataset = ImageFilelist(root, file_list, transform=transform) + #transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list + #transform = transforms.Compose(transform_list) + dataset = ImageFilelist(root, file_list, transform=transform_list) loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train, drop_last=True, num_workers=num_workers) return loader