Merge pull request junyanz#276 from SsnL/pytorch04
PyTorch 0.4 compatibility
junyanz authored May 23, 2018
2 parents 43585ae + 7abbeaa commit 235bc40
Showing 10 changed files with 72 additions and 111 deletions.
26 changes: 15 additions & 11 deletions README.md
@@ -19,20 +19,20 @@ This PyTorch implementation produces results comparable or better than our origi

 <img src="https://phillipi.github.io/pix2pix/images/teaser_v3.png" width="900px"/>

-#### [[EdgesCats Demo]](https://affinelayer.com/pixsrv/) [[pix2pix-tensorflow]](https://github.com/affinelayer/pix2pix-tensorflow)
-Written by [Christopher Hesse](https://twitter.com/christophrhesse)
+#### [[EdgesCats Demo]](https://affinelayer.com/pixsrv/) [[pix2pix-tensorflow]](https://github.com/affinelayer/pix2pix-tensorflow)
+Written by [Christopher Hesse](https://twitter.com/christophrhesse)

 <img src='imgs/edges2cats.jpg' width="600px"/>

 If you use this code for your research, please cite:

-Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
-[Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*, [Taesung Park](https://taesung.me/)\*, [Phillip Isola](https://people.eecs.berkeley.edu/~isola/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros)
-In ICCV 2017. (* equal contributions)
+Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
+[Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*, [Taesung Park](https://taesung.me/)\*, [Phillip Isola](https://people.eecs.berkeley.edu/~isola/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros)
+In ICCV 2017. (* equal contributions)


-Image-to-Image Translation with Conditional Adversarial Networks
-[Phillip Isola](https://people.eecs.berkeley.edu/~isola), [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz), [Tinghui Zhou](https://people.eecs.berkeley.edu/~tinghuiz), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros)
+Image-to-Image Translation with Conditional Adversarial Networks
+[Phillip Isola](https://people.eecs.berkeley.edu/~isola), [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz), [Tinghui Zhou](https://people.eecs.berkeley.edu/~tinghuiz), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros)
 In CVPR 2017.

 ## Course
@@ -83,6 +83,10 @@ python setup.py install
 pip install visdom
 pip install dominate
 ```
+- Alternatively, all dependencies can be installed by
+```bash
+pip install -r requirements.txt
+```
 - Clone this repo:
 ```bash
 git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
@@ -175,7 +179,7 @@ Note that we specified `--which_direction BtoA` as Facades dataset's A to B dire

 ## Training/test Details
 - Flags: see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags.
-- CPU/GPU (default `--gpu_ids 0`): set`--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batchSize 32`) to benefit from multiple GPUs.
+- CPU/GPU (default `--gpu_ids 0`): set`--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batchSize 32`) to benefit from multiple GPUs.
 - Visualization: during training, the current results can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and a server running by the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom` set `--display_id 0`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`.
 - Preprocessing: images can be resized and cropped in different ways using `--resize_or_crop` option. The default option `'resize_and_crop'` resizes the image to be of size `(opt.loadSize, opt.loadSize)` and does a random crop of size `(opt.fineSize, opt.fineSize)`. `'crop'` skips the resizing step and only performs random cropping. `'scale_width'` resizes the image to have width `opt.fineSize` while keeping the aspect ratio. `'scale_width_and_crop'` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`.
 - Fine-tuning/Resume training: to fine-tune a pre-trained model, or resume the previous training, use the `--continue_train` flag. The program will then load the model based on `which_epoch`. By default, the program will initialize the epoch count as 1. Set `--epoch_count <int>` to specify a different starting epoch count.
@@ -244,12 +248,12 @@ If you use this code for your research, please cite our papers.
 ```

 ## Related Projects
-[CycleGAN](https://github.com/junyanz/CycleGAN): Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
-[pix2pix](https://github.com/phillipi/pix2pix): Image-to-image translation with conditional adversarial nets
+[CycleGAN](https://github.com/junyanz/CycleGAN): Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
+[pix2pix](https://github.com/phillipi/pix2pix): Image-to-image translation with conditional adversarial nets
 [iGAN](https://github.com/junyanz/iGAN): Interactive Image Generation via Generative Adversarial Networks

 ## Cat Paper Collection
-If you love cats, and love reading cool graphics, vision, and learning papers, please check out the Cat Paper Collection:
+If you love cats, and love reading cool graphics, vision, and learning papers, please check out the Cat Paper Collection:
 [[Github]](https://github.com/junyanz/CatPapers) [[Webpage]](https://people.eecs.berkeley.edu/~junyanz/cat/cat_papers.html)

 ## Acknowledgments
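Review note on the README's Preprocessing bullet: the four `--resize_or_crop` modes differ only in the size after resizing and the size after cropping. The sketch below tabulates both. `preprocess_sizes` is a hypothetical helper written for this note, not code from the repository. Rounding the scaled height to the nearest integer is an assumption, and the sample values 286 and 256 are chosen only for illustration.

```python
def preprocess_sizes(width, height, mode, load_size, fine_size):
    """Return ((resized_w, resized_h), (final_w, final_h)) for one image.

    Hypothetical helper: it follows the README wording, not the
    repository's actual transform code.
    """
    def scale_to_width(target_width):
        # Keep the aspect ratio; rounding here is an assumption.
        return (target_width, int(round(height * target_width / width)))

    crop = (fine_size, fine_size)
    if mode == 'resize_and_crop':
        return (load_size, load_size), crop
    if mode == 'crop':
        return (width, height), crop          # no resize, random crop only
    if mode == 'scale_width':
        scaled = scale_to_width(fine_size)
        return scaled, scaled                 # no crop afterwards
    if mode == 'scale_width_and_crop':
        return scale_to_width(load_size), crop
    raise ValueError('unknown resize_or_crop mode: %s' % mode)


for mode in ('resize_and_crop', 'crop', 'scale_width', 'scale_width_and_crop'):
    print(mode, preprocess_sizes(400, 600, mode, load_size=286, fine_size=256))
```

Only `'scale_width'` can produce a non-square result, which matches the `cudnn.benchmark` guard on `opt.resize_or_crop != 'scale_width'` in `models/base_model.py`.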
35 changes: 27 additions & 8 deletions models/base_model.py
@@ -11,7 +11,7 @@ def initialize(self, opt):
         self.opt = opt
         self.gpu_ids = opt.gpu_ids
         self.isTrain = opt.isTrain
-        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
+        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
         if opt.resize_or_crop != 'scale_width':
             torch.backends.cudnn.benchmark = True
@@ -26,9 +26,11 @@ def set_input(self, input):
     def forward(self):
         pass

-    # used in test time, no backprop
+    # used in test time, wrapping `forward` in no_grad() so we don't save
+    # intermediate steps for backprop
     def test(self):
-        pass
+        with torch.no_grad():
+            self.forward()

     # get image paths
     def get_image_paths(self):
@@ -57,7 +59,8 @@ def get_current_losses(self):
         errors_ret = OrderedDict()
         for name in self.loss_names:
             if isinstance(name, str):
-                errors_ret[name] = getattr(self, 'loss_' + name)
+                # float(...) works for both scalar tensor and float number
+                errors_ret[name] = float(getattr(self, 'loss_' + name))
         return errors_ret

     # save models to the disk
@@ -74,17 +77,33 @@ def save_networks(self, which_epoch):
                 else:
                     torch.save(net.cpu().state_dict(), save_path)

+
+    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+        key = keys[i]
+        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'running_mean' or key == 'running_var'):
+                if getattr(module, key) is None:
+                    state_dict.pop('.'.join(keys))
+        else:
+            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
     # load models from the disk
     def load_networks(self, which_epoch):
         for name in self.model_names:
             if isinstance(name, str):
                 save_filename = '%s_net_%s.pth' % (which_epoch, name)
                 save_path = os.path.join(self.save_dir, save_filename)
                 net = getattr(self, 'net' + name)
-                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
-                    net.module.load_state_dict(torch.load(save_path))
-                else:
-                    net.load_state_dict(torch.load(save_path))
+                if isinstance(net, torch.nn.DataParallel):
+                    net = net.module
+                # if you are using PyTorch newer than 0.4 (e.g., built from
+                # GitHub source), you can remove str() on self.device
+                state_dict = torch.load(save_path, map_location=str(self.device))
+                # patch InstanceNorm checkpoints prior to 0.4
+                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
+                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+                net.load_state_dict(state_dict)

     # print network information
     def print_networks(self, verbose):
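Review note on `__patch_instance_norm_state_dict`: the helper walks a dotted checkpoint key down the module tree and drops `running_mean` / `running_var` entries when the target InstanceNorm layer holds `None` for them. The sketch below replays that walk without PyTorch. The `InstanceNorm2d` and `TinyNet` classes are plain-Python stand-ins invented for this note (assumptions, not the real `torch.nn` classes), and the checkpoint values are made up.

```python
class InstanceNorm2d:
    """Stand-in layer: its class name starts with 'InstanceNorm' and it
    tracks no running statistics, so both buffers are None."""
    running_mean = None
    running_var = None


class TinyNet:
    """Stand-in network with one sub-module reachable as `norm`."""
    def __init__(self):
        self.norm = InstanceNorm2d()


def patch_instance_norm_state_dict(state_dict, module, keys, i=0):
    key = keys[i]
    if i + 1 == len(keys):  # last component names a parameter or buffer
        if module.__class__.__name__.startswith('InstanceNorm') and \
                key in ('running_mean', 'running_var'):
            if getattr(module, key) is None:
                state_dict.pop('.'.join(keys))  # stale entry from an old checkpoint
    else:
        # descend one level: 'norm.running_mean' -> module.norm, then 'running_mean'
        patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)


# Made-up checkpoint contents in the pre-0.4 layout.
state_dict = {
    'norm.weight': [1.0],
    'norm.running_mean': [0.0],
    'norm.running_var': [1.0],
}
net = TinyNet()
for key in list(state_dict.keys()):  # copy the keys because the dict is mutated
    patch_instance_norm_state_dict(state_dict, net, key.split('.'))
print(sorted(state_dict))
```

Iterating over `list(state_dict.keys())` matters here for the same reason as in the diff: popping from a dict while iterating over its live key view raises a `RuntimeError`.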
22 changes: 3 additions & 19 deletions models/cycle_gan_model.py
@@ -1,5 +1,4 @@
 import torch
-from torch.autograd import Variable
 import itertools
 from util.image_pool import ImagePool
 from .base_model import BaseModel
@@ -50,7 +49,7 @@ def initialize(self, opt):
             self.fake_A_pool = ImagePool(opt.pool_size)
             self.fake_B_pool = ImagePool(opt.pool_size)
             # define loss functions
-            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
+            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
             self.criterionCycle = torch.nn.L1Loss()
             self.criterionIdt = torch.nn.L1Loss()
             # initialize optimizers
@@ -71,25 +70,14 @@ def initialize(self, opt):

     def set_input(self, input):
         AtoB = self.opt.which_direction == 'AtoB'
-        input_A = input['A' if AtoB else 'B']
-        input_B = input['B' if AtoB else 'A']
-        if len(self.gpu_ids) > 0:
-            input_A = input_A.cuda(self.gpu_ids[0], async=True)
-            input_B = input_B.cuda(self.gpu_ids[0], async=True)
-        self.input_A = input_A
-        self.input_B = input_B
+        self.real_A = input['A' if AtoB else 'B'].to(self.device)
+        self.real_B = input['B' if AtoB else 'A'].to(self.device)
         self.image_paths = input['A_paths' if AtoB else 'B_paths']

     def forward(self):
-        self.real_A = Variable(self.input_A)
-        self.real_B = Variable(self.input_B)
-
-    def test(self):
-        self.real_A = Variable(self.input_A, volatile=True)
         self.fake_B = self.netG_A(self.real_A)
         self.rec_A = self.netG_B(self.fake_B)

-        self.real_B = Variable(self.input_B, volatile=True)
         self.fake_A = self.netG_B(self.real_B)
         self.rec_B = self.netG_A(self.fake_A)

@@ -131,19 +119,15 @@ def backward_G(self):
             self.loss_idt_B = 0

         # GAN loss D_A(G_A(A))
-        self.fake_B = self.netG_A(self.real_A)
         self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)

         # GAN loss D_B(G_B(B))
-        self.fake_A = self.netG_B(self.real_B)
         self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

         # Forward cycle loss
-        self.rec_A = self.netG_B(self.fake_B)
         self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A

         # Backward cycle loss
-        self.rec_B = self.netG_A(self.fake_A)
         self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
         # combined loss
         self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
44 changes: 14 additions & 30 deletions models/networks.py
@@ -2,7 +2,6 @@
 import torch.nn as nn
 from torch.nn import init
 import functools
-from torch.autograd import Variable
 from torch.optim import lr_scheduler

 ###############################################################################
@@ -42,20 +41,20 @@ def init_func(m):
         classname = m.__class__.__name__
         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
             if init_type == 'normal':
-                init.normal(m.weight.data, 0.0, gain)
+                init.normal_(m.weight.data, 0.0, gain)
             elif init_type == 'xavier':
-                init.xavier_normal(m.weight.data, gain=gain)
+                init.xavier_normal_(m.weight.data, gain=gain)
             elif init_type == 'kaiming':
-                init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
+                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
             elif init_type == 'orthogonal':
-                init.orthogonal(m.weight.data, gain=gain)
+                init.orthogonal_(m.weight.data, gain=gain)
             else:
                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
             if hasattr(m, 'bias') and m.bias is not None:
-                init.constant(m.bias.data, 0.0)
+                init.constant_(m.bias.data, 0.0)
         elif classname.find('BatchNorm2d') != -1:
-            init.normal(m.weight.data, 1.0, gain)
-            init.constant(m.bias.data, 0.0)
+            init.normal_(m.weight.data, 1.0, gain)
+            init.constant_(m.bias.data, 0.0)

     print('initialize network with %s' % init_type)
     net.apply(init_func)
@@ -64,7 +63,7 @@ def init_func(m):
 def init_net(net, init_type='normal', gpu_ids=[]):
     if len(gpu_ids) > 0:
         assert(torch.cuda.is_available())
-        net.cuda(gpu_ids[0])
+        net.to(gpu_ids[0])
         net = torch.nn.DataParallel(net, gpu_ids)
     init_weights(net, init_type)
     return net
@@ -114,36 +113,21 @@ def define_D(input_nc, ndf, which_model_netD,
 # but it abstracts away the need to create the target label tensor
 # that has the same size as the input
 class GANLoss(nn.Module):
-    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
-                 tensor=torch.FloatTensor):
+    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
         super(GANLoss, self).__init__()
-        self.real_label = target_real_label
-        self.fake_label = target_fake_label
-        self.real_label_var = None
-        self.fake_label_var = None
-        self.Tensor = tensor
+        self.register_buffer('real_label', torch.tensor(target_real_label))
+        self.register_buffer('fake_label', torch.tensor(target_fake_label))
         if use_lsgan:
             self.loss = nn.MSELoss()
         else:
             self.loss = nn.BCELoss()

     def get_target_tensor(self, input, target_is_real):
-        target_tensor = None
         if target_is_real:
-            create_label = ((self.real_label_var is None) or
-                            (self.real_label_var.numel() != input.numel()))
-            if create_label:
-                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
-                self.real_label_var = Variable(real_tensor, requires_grad=False)
-            target_tensor = self.real_label_var
+            target_tensor = self.real_label
         else:
-            create_label = ((self.fake_label_var is None) or
-                            (self.fake_label_var.numel() != input.numel()))
-            if create_label:
-                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
-                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
-            target_tensor = self.fake_label_var
-        return target_tensor
+            target_tensor = self.fake_label
+        return target_tensor.expand_as(input)

     def __call__(self, input, target_is_real):
         target_tensor = self.get_target_tensor(input, target_is_real)
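Review note on the `GANLoss` rewrite: the labels become zero-dimensional buffers, so `.to(device)` on the module moves them along with it, and `expand_as` gives a target of the input's shape without allocating a filled tensor. A minimal sketch of that pattern, assuming PyTorch 0.4 or later is installed. `TinyGANLoss` is a cut-down illustration written for this note (least-squares branch only), not the class from `models/networks.py`.

```python
import torch
import torch.nn as nn


class TinyGANLoss(nn.Module):
    """Cut-down illustration of the buffer + expand_as pattern (LSGAN branch only)."""

    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        super(TinyGANLoss, self).__init__()
        # Buffers follow the module through .to(device) and appear in state_dict().
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss()

    def forward(self, prediction, target_is_real):
        label = self.real_label if target_is_real else self.fake_label
        # expand_as returns a view with the prediction's shape; no data is copied.
        return self.loss(prediction, label.expand_as(prediction))


criterion = TinyGANLoss()
prediction = torch.ones(2, 1, 3, 3)          # a discriminator output of all ones
loss_real = criterion(prediction, True)      # compared against the label 1.0
loss_fake = criterion(prediction, False)     # compared against the label 0.0
print(float(loss_real), float(loss_fake))
```

This also removes the old cache check on `numel()`: the expanded view is rebuilt on every call, so a change in input size needs no special handling.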
20 changes: 3 additions & 17 deletions models/pix2pix_model.py
@@ -1,5 +1,4 @@
 import torch
-from torch.autograd import Variable
 from util.image_pool import ImagePool
 from .base_model import BaseModel
 from . import networks
@@ -34,7 +33,7 @@ def initialize(self, opt):
         if self.isTrain:
             self.fake_AB_pool = ImagePool(opt.pool_size)
             # define loss functions
-            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
+            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
             self.criterionL1 = torch.nn.L1Loss()

             # initialize optimizers
@@ -56,25 +55,12 @@ def initialize(self, opt):

     def set_input(self, input):
         AtoB = self.opt.which_direction == 'AtoB'
-        input_A = input['A' if AtoB else 'B']
-        input_B = input['B' if AtoB else 'A']
-        if len(self.gpu_ids) > 0:
-            input_A = input_A.cuda(self.gpu_ids[0], async=True)
-            input_B = input_B.cuda(self.gpu_ids[0], async=True)
-        self.input_A = input_A
-        self.input_B = input_B
+        self.real_A = input['A' if AtoB else 'B'].to(self.device)
+        self.real_B = input['B' if AtoB else 'A'].to(self.device)
         self.image_paths = input['A_paths' if AtoB else 'B_paths']

     def forward(self):
-        self.real_A = Variable(self.input_A)
         self.fake_B = self.netG(self.real_A)
-        self.real_B = Variable(self.input_B)
-
-    # no backprop gradients
-    def test(self):
-        self.real_A = Variable(self.input_A, volatile=True)
-        self.fake_B = self.netG(self.real_A)
-        self.real_B = Variable(self.input_B, volatile=True)

     def backward_D(self):
         # Fake
8 changes: 2 additions & 6 deletions models/test_model.py
@@ -28,12 +28,8 @@ def initialize(self, opt):

     def set_input(self, input):
         # we need to use single_dataset mode
-        input_A = input['A']
-        if len(self.gpu_ids) > 0:
-            input_A = input_A.cuda(self.gpu_ids[0], async=True)
-        self.input_A = input_A
+        self.real_A = input['A'].to(self.device)
         self.image_paths = input['A_paths']

-    def test(self):
-        self.real_A = Variable(self.input_A, volatile=True)
+    def forward(self):
         self.fake_B = self.netG(self.real_A)
4 changes: 4 additions & 0 deletions requirements.txt
@@ -0,0 +1,4 @@
+torch>=0.4.0
+torchvision>=0.2.1
+dominate>=2.3.1
+visdom>=0.1.8.3
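Review note on `requirements.txt`: every entry is a `name>=minimum` pin, so a quick sanity check of an environment only needs a numeric comparison. The sketch below is a simplified illustration that assumes purely numeric dotted versions. `parse_requirement`, `version_tuple` and `satisfies` are hypothetical helpers written for this note, and real version strings with suffixes would need a proper parser.

```python
REQUIREMENTS = """\
torch>=0.4.0
torchvision>=0.2.1
dominate>=2.3.1
visdom>=0.1.8.3
"""


def version_tuple(text, width=4):
    """Turn '0.4.0' into (0, 4, 0, 0); padding makes '0.4' equal to '0.4.0'."""
    parts = [int(part) for part in text.split('.')]
    return tuple(parts + [0] * (width - len(parts)))


def parse_requirement(line):
    """Split 'name>=minimum' into (name, minimum_tuple)."""
    name, _, minimum = line.strip().partition('>=')
    return name, version_tuple(minimum)


def satisfies(installed, minimum):
    """True if the installed version string meets the minimum tuple."""
    return version_tuple(installed) >= minimum


minimums = dict(parse_requirement(line) for line in REQUIREMENTS.splitlines())
print(satisfies('0.3.1', minimums['torch']), satisfies('0.4.0', minimums['torch']))
```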
5 changes: 2 additions & 3 deletions util/image_pool.py
@@ -1,6 +1,5 @@
 import random
 import torch
-from torch.autograd import Variable


 class ImagePool():
@@ -23,11 +22,11 @@ def query(self, images):
             else:
                 p = random.uniform(0, 1)
                 if p > 0.5:
-                    random_id = random.randint(0, self.pool_size - 1)
+                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                     tmp = self.images[random_id].clone()
                     self.images[random_id] = image
                     return_images.append(tmp)
                 else:
                     return_images.append(image)
-        return_images = Variable(torch.cat(return_images, 0))
+        return_images = torch.cat(return_images, 0)
         return return_images
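Review note on `ImagePool.query`: the hunk shows only the tail of the method, where a full pool either hands back a stored image (and keeps the new one) or passes the new image through. The sketch below reproduces that buffer logic with plain Python lists so it runs without PyTorch. `ListImagePool` is a hypothetical stand-in, and the parts outside the hunk (the `pool_size == 0` shortcut and the fill phase) are assumptions about the surrounding code.

```python
import random


class ListImagePool:
    """Plain-Python sketch of the pool logic; list items stand in for image tensors."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:  # assumed shortcut: pooling disabled
            return list(images)
        returned = []
        for image in images:
            if len(self.images) < self.pool_size:
                # Pool not full yet: store the new image and return it.
                self.images.append(image)
                returned.append(image)
            elif random.uniform(0, 1) > 0.5:
                # Swap: return a stored image and keep the new one in its slot.
                random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                returned.append(self.images[random_id])
                self.images[random_id] = image
            else:
                # Pass the new image through; the pool is unchanged.
                returned.append(image)
        return returned


pool = ListImagePool(pool_size=2)
first = pool.query(['a', 'b'])        # fills the pool
second = pool.query(['c', 'd', 'e'])  # each item is swapped or passed through
print(first, len(second), len(pool.images))
```

Because `random.randint` includes both endpoints, the upper bound has to be `pool_size - 1`, which is what the comment added in the diff records.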
17 changes: 1 addition & 16 deletions util/util.py
@@ -3,27 +3,12 @@
 import numpy as np
 from PIL import Image
 import os
-from torch import is_tensor
-from torch.autograd import Variable


-# Converts a Tensor into a float
-def tensor2float(input_error):
-    if is_tensor(input_error):
-        error = input_error[0]
-    elif isinstance(input_error, Variable):
-        error = input_error.data[0]
-    else:
-        error = input_error
-    return error
-
-
 # Converts a Tensor into an image array (numpy)
 # |imtype|: the desired type of the converted numpy array
 def tensor2im(input_image, imtype=np.uint8):
-    if is_tensor(input_image):
-        image_tensor = input_image
-    elif isinstance(input_image, Variable):
+    if isinstance(input_image, torch.Tensor):
         image_tensor = input_image.data
     else:
         return input_image
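Review note on the removal of `tensor2float`: its three branches are replaced by the single `float(...)` call in `get_current_losses`, which works for anything that implements `__float__` as well as for plain numbers such as the `0` assigned to `loss_idt_B`. The sketch below shows that behaviour without PyTorch. `ScalarLike`, `Holder` and `collect_losses` are hypothetical names written for this note, with `ScalarLike` standing in for a zero-dimensional tensor.

```python
from collections import OrderedDict


class ScalarLike:
    """Stand-in for a zero-dimensional tensor: any object defining __float__."""

    def __init__(self, value):
        self.value = value

    def __float__(self):
        return float(self.value)


class Holder:
    """Stand-in for a model that stores losses as `loss_<name>` attributes."""
    loss_G_A = ScalarLike(0.25)  # tensor-like loss
    loss_idt_A = 0               # plain number, as when the identity loss is disabled


def collect_losses(holder, loss_names):
    losses = OrderedDict()
    for name in loss_names:
        if isinstance(name, str):
            # float(...) handles both the tensor-like object and the plain number
            losses[name] = float(getattr(holder, 'loss_' + name))
    return losses


losses = collect_losses(Holder(), ['G_A', 'idt_A'])
print(losses)
```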