implement highway norm #73

Open · wants to merge 1 commit into master
62 changes: 62 additions & 0 deletions ccai.yaml
@@ -0,0 +1,62 @@
# Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

# logger options
image_save_iter: 5000 # How often do you want to save output images during training
image_display_iter: 200 # How often do you want to display output images during training
display_size: 16 # How many images do you want to display each time
snapshot_save_iter: 10000 # How often do you want to save trained models
log_iter: 1 # How often do you want to log the training stats

# optimization options
max_iter: 400000 # maximum number of training iterations
batch_size: 1 # batch size
weight_decay: 0.0001 # weight decay
beta1: 0.5 # Adam parameter
beta2: 0.999 # Adam parameter
init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
lr: 0.0001 # initial learning rate
lr_policy: step # learning rate scheduler
step_size: 100000 # how often to decay learning rate
gamma: 0.5 # how much to decay learning rate
gan_w: 1 # weight of adversarial loss
recon_x_w: 5 # weight of image reconstruction loss
recon_s_w: 1 # weight of style reconstruction loss
recon_c_w: 1 # weight of content reconstruction loss
recon_x_cyc_w: 5 # weight of explicit style augmented cycle consistency loss
vgg_w: 0 # weight of domain-invariant perceptual loss

# model options
gen:
  dim: 64               # number of filters in the bottommost layer
  mlp_dim: 256          # number of filters in MLP
  style_dim: 12         # length of style code
  activ: relu           # activation function [relu/lrelu/prelu/selu/tanh]
  n_downsample: 2       # number of downsampling layers in content encoder
  n_res: 4              # number of residual blocks in content encoder/decoder
  pad_type: reflect     # padding type [zero/reflect]
dis:
  dim: 64               # number of filters in the bottommost layer
  norm: none            # normalization layer [none/bn/in/ln]
  activ: lrelu          # activation function [relu/lrelu/prelu/selu/tanh]
  n_layer: 4            # number of layers in D
  gan_type: lsgan       # GAN loss [lsgan/nsgan]
  num_scales: 3         # number of scales
  pad_type: reflect     # padding type [zero/reflect]

# data options
input_dim_a: 3 # number of image channels [1/3]
input_dim_b: 3 # number of image channels [1/3]
num_workers: 0 # number of data loading threads
new_size: 256 # first resize the shortest image side to this size
crop_image_height: 256 # random crop image of this height
crop_image_width: 256 # random crop image of this width

data_folder_train_a: .
data_list_train_a: ./data/mapillary.txt
data_folder_test_a: .
data_list_test_a: ./data/mapillary.txt
data_folder_train_b: .
data_list_train_b: ./data/sorted_crf.txt
data_folder_test_b: .
data_list_test_b: ./data/sorted_crf.txt
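
For reference, a minimal sketch of how a MUNIT-style YAML config such as ccai.yaml is typically read into a nested dict; the get_config name mirrors the helper in MUNIT's utils.py, but this snippet is illustrative and not part of the PR.

# Illustrative only: load ccai.yaml and read a few fields; assumes PyYAML is installed.
import yaml

def get_config(path):
    with open(path, 'r') as stream:
        return yaml.safe_load(stream)

config = get_config('ccai.yaml')
print(config['lr'], config['gen']['style_dim'], config['dis']['gan_type'])
# expected: 0.0001 12 lsgan
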
144 changes: 144 additions & 0 deletions data_seg.py
@@ -0,0 +1,144 @@
"""
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import torch.utils.data as data
import os.path
import numpy as np
from torchvision import transforms


def default_loader(path, convert=True):
    if convert:
        return Image.open(path).convert('RGB')
    else:
        return Image.open(path)


def default_flist_reader(flist):
    """
    flist format: impath segpath\nimpath segpath\n ... (same as caffe's filelist)
    """
    imlist = []
    seglist = []
    with open(flist, 'r') as rf:
        for line in rf.readlines():
            impath, segpath = line.strip().split(' ')
            imlist.append(impath)
            seglist.append(segpath)
            if not os.path.exists(segpath):
                print('not found', segpath)

    return imlist, seglist


class ImageFilelist(data.Dataset):
    def __init__(self, root, flist, transform=None,
                 flist_reader=default_flist_reader, loader=default_loader):
        self.root = root
        self.imlist, self.seglist = flist_reader(flist)
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        impath = self.imlist[index]
        segpath = self.seglist[index]
        img = self.loader(os.path.join(self.root, impath))
        seg = self.loader(os.path.join(self.root, segpath), False)  # keep the label map's original mode
        transform1 = transforms.Compose(self.transform)
        transform2 = transforms.Compose(self.transform + [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        img = transform2(img)   # image ends up normalized to [-1, 1]
        seg = transform1(seg) * 255  # ToTensor scales to [0, 1]; rescale to recover integer label ids
        #print('here', impath)
        return img, seg, impath

    def __len__(self):
        return len(self.imlist)


class ImageLabelFilelist(data.Dataset):
    def __init__(self, root, flist, transform=None,
                 flist_reader=default_flist_reader, loader=default_loader):
        self.root = root
        self.imlist = flist_reader(os.path.join(self.root, flist))
        self.transform = transform
        self.loader = loader
        self.classes = sorted(list(set([path.split('/')[0] for path in self.imlist])))
        self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))}
        self.imgs = [(impath, self.class_to_idx[impath.split('/')[0]]) for impath in self.imlist]

    def __getitem__(self, index):
        impath, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, impath))
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)

###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = sorted(make_dataset(root))
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
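
A hypothetical usage sketch for the ImageFilelist class defined above; it assumes a file list in the "impath segpath" per-line format expected by default_flist_reader and passes transforms as a plain list, since the class composes them itself. The list path comes from ccai.yaml; everything else is illustrative, not part of the PR.

# Illustrative only: iterate image/segmentation pairs with ImageFilelist.
from torch.utils.data import DataLoader
from torchvision import transforms

transform_list = [transforms.ToTensor()]  # ImageFilelist expects a list, not a Compose
dataset = ImageFilelist(root='.', flist='./data/mapillary.txt', transform=transform_list)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

img, seg, path = next(iter(loader))
# img is normalized to [-1, 1]; seg holds integer label ids rescaled from [0, 1]
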
76 changes: 48 additions & 28 deletions networks.py
@@ -66,6 +66,7 @@ def calc_dis_loss(self, input_fake, input_real):
                                    F.binary_cross_entropy(F.sigmoid(out1), all1))
             else:
                 assert 0, "Unsupported GAN type: {}".format(self.gan_type)
+        print('dis loss', loss.data)
         return loss
 
     def calc_gen_loss(self, input_fake):
@@ -96,36 +97,43 @@ def __init__(self, input_dim, params):
         n_res = params['n_res']
         activ = params['activ']
         pad_type = params['pad_type']
-        mlp_dim = params['mlp_dim']
+        self.mlp_dim = params['mlp_dim']
+        self.style_dim = style_dim
 
         # style encoder
         self.enc_style = StyleEncoder(4, input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type)
 
         # content encoder
         self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
-        self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, res_norm='adain', activ=activ, pad_type=pad_type)
+        self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim,
+                           self.style_dim, self.mlp_dim, res_norm='adain', activ=activ, pad_type=pad_type)
 
         # MLP to generate AdaIN parameters
-        self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ)
+        self.mlp1 = MLP(style_dim, self.get_num_adain_params(self.dec), self.mlp_dim, 3, norm='none', activ=activ)
+        self.mlp2 = MLP(style_dim, self.get_num_adain_params(self.dec), self.mlp_dim, 3, norm='none', activ=activ)
 
-    def forward(self, images):
+    def forward(self, images, masks):
         # reconstruct an image
-        content, style_fake = self.encode(images)
-        images_recon = self.decode(content, style_fake)
+        content, style_fg, style_bg = self.encode(images, masks)
+        images_recon = self.decode(content, style_fg, style_bg, masks)
         return images_recon
 
-    def encode(self, images):
+    def encode(self, images, masks):
         # encode an image to its content and style codes
-        style_fake = self.enc_style(images)
+        style_fg = self.enc_style(masks * images)
+        style_bg = self.enc_style((1-masks) * images, False)
         content = self.enc_content(images)
-        return content, style_fake
+        return content, style_fg, style_bg
 
-    def decode(self, content, style):
+    def decode(self, content, style_fg, style_bg, mask):
         # decode content and style codes to an image
-        adain_params = self.mlp(style)
-        self.assign_adain_params(adain_params, self.dec)
-        images = self.dec(content)
-        return images
+        adain_params1 = self.mlp1(style_fg)
+        adain_params2 = self.mlp2(style_bg)
+        self.assign_adain_params(adain_params1, self.dec)
+        images1 = self.dec(content, style_fg, style_bg, mask)
+        self.assign_adain_params(adain_params2, self.dec)
+        images2 = self.dec(content, style_bg, style_bg, mask)
+        return (mask * images1) + (1-mask) * images2
 
     def assign_adain_params(self, adain_params, model):
         # assign the adain_params to the AdaIN layers in model
@@ -148,7 +156,7 @@ def get_num_adain_params(self, model):
 
 
 class VAEGen(nn.Module):
-    # VAE architecture
+    # VAE architectureself
     def __init__(self, input_dim, params):
         super(VAEGen, self).__init__()
         dim = params['dim']
@@ -188,20 +196,30 @@ def decode(self, hiddens):
 class StyleEncoder(nn.Module):
     def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
         super(StyleEncoder, self).__init__()
-        self.model = []
-        self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
+        self.shared = []
+        self.spec_layers = []
+        self.shared += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
         for i in range(2):
-            self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
+            self.shared += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
             dim *= 2
         for i in range(n_downsample - 2):
-            self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
-        self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
-        self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
-        self.model = nn.Sequential(*self.model)
+            self.spec_layers += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
+        self.spec_layers = [nn.AdaptiveAvgPool2d(1)] # global average pooling
+        self.spec_layers += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
+
+        self.shared = nn.Sequential(*self.shared)
+        self.foreground = nn.Sequential(*self.spec_layers)
+        self.background = nn.Sequential(*self.spec_layers)
+
         self.output_dim = dim
 
-    def forward(self, x):
-        return self.model(x)
+    def forward(self, x, isF=True):
+        x = self.shared(x)
+        if isF:
+            return self.foreground(x)
+        else:
+            return self.background(x)
 
 
 class ContentEncoder(nn.Module):
     def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
@@ -220,13 +238,14 @@ def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
     def forward(self, x):
         return self.model(x)
 
+
 class Decoder(nn.Module):
-    def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
+    def __init__(self, n_upsample, n_res, dim, output_dim, style_dim, mlp_dim, res_norm='adain', activ='relu', pad_type='zero'):
         super(Decoder, self).__init__()
 
         self.model = []
         # AdaIN residual blocks
-        self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
+        self.model += [ResBlocks(n_res, dim, 'adain', activ, pad_type=pad_type)]
         # upsampling blocks
         for i in range(n_upsample):
             self.model += [nn.Upsample(scale_factor=2),
@@ -236,8 +255,9 @@ def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='
         self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
         self.model = nn.Sequential(*self.model)
 
-    def forward(self, x):
-        return self.model(x)
+    def forward(self, c_A, s_fB, s_bA, m_A):
+        return self.model(c_A)
+
 
 ##################################################################################
 # Sequential Models
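
The core of the AdaINGen.decode change above is a two-pass decode blended with a foreground mask. Below is a standalone sketch of just that compositing step, with illustrative tensor names and shapes that are not part of the PR.

# Illustrative only: blend a foreground-styled decode and a background-styled
# decode with a binary mask, as in the modified AdaINGen.decode.
import torch

def composite(images_fg, images_bg, mask):
    # mask: (N, 1, H, W) with 1 = foreground; broadcasts over the channel dimension
    return mask * images_fg + (1 - mask) * images_bg

x_fg = torch.randn(1, 3, 256, 256)
x_bg = torch.randn(1, 3, 256, 256)
m = (torch.rand(1, 1, 256, 256) > 0.5).float()
out = composite(x_fg, x_bg, m)  # shape (1, 3, 256, 256)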