Skip to content

Commit

Permalink
simplify the code and add eval mode
Browse files Browse the repository at this point in the history
  • Loading branch information
junyanz committed May 24, 2018
1 parent 714a932 commit aa0b8a9
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 23 deletions.
20 changes: 18 additions & 2 deletions models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch
from collections import OrderedDict
from . import networks


class BaseModel():
Expand All @@ -26,6 +27,22 @@ def set_input(self, input):
def forward(self):
pass

# load and print networks; create shedulars
def setup(self, opt):
if self.isTrain:
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]

if not self.isTrain or opt.continue_train:
self.load_networks(opt.which_epoch)
self.print_networks(opt.verbose)

# make models eval mode during test time
def eval(self):
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.eval()

# used in test time, wrapping `forward` in no_grad() so we don't save
# intermediate steps for backprop
def test(self):
Expand Down Expand Up @@ -77,7 +94,6 @@ def save_networks(self, which_epoch):
else:
torch.save(net.cpu().state_dict(), save_path)


def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
Expand All @@ -101,7 +117,7 @@ def load_networks(self, which_epoch):
# GitHub source), you can remove str() on self.device
state_dict = torch.load(save_path, map_location=str(self.device))
# patch InstanceNorm checkpoints prior to 0.4
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
net.load_state_dict(state_dict)

Expand Down
7 changes: 1 addition & 6 deletions models/cycle_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,10 @@ def initialize(self, opt):
self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers = []
self.schedulers = []
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
for optimizer in self.optimizers:
self.schedulers.append(networks.get_scheduler(optimizer, opt))

if not self.isTrain or opt.continue_train:
self.load_networks(opt.which_epoch)
self.print_networks(opt.verbose)
self.setup(opt)

def set_input(self, input):
AtoB = self.opt.which_direction == 'AtoB'
Expand Down
8 changes: 1 addition & 7 deletions models/pix2pix_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,15 @@ def initialize(self, opt):
self.criterionL1 = torch.nn.L1Loss()

# initialize optimizers
self.schedulers = []
self.optimizers = []
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
for optimizer in self.optimizers:
self.schedulers.append(networks.get_scheduler(optimizer, opt))

if not self.isTrain or opt.continue_train:
self.load_networks(opt.which_epoch)

self.print_networks(opt.verbose)
self.setup(opt)

def set_input(self, input):
AtoB = self.opt.which_direction == 'AtoB'
Expand Down
4 changes: 1 addition & 3 deletions models/test_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from torch.autograd import Variable
from .base_model import BaseModel
from . import networks

Expand All @@ -23,8 +22,7 @@ def initialize(self, opt):
opt.norm, not opt.no_dropout,
opt.init_type,
self.gpu_ids)
self.load_networks(opt.which_epoch)
self.print_networks(opt.verbose)
self.setup(opt)

def set_input(self, input):
# we need to use single_dataset mode
Expand Down
20 changes: 15 additions & 5 deletions scripts/check_all.sh
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
set -ex
DOWNLOAD=${1}
echo 'apply a pretrained cyclegan model'
bash pretrained_models/download_cyclegan_model.sh horse2zebra
bash ./datasets/download_cyclegan_dataset.sh horse2zebra
if [ ${DOWNLOAD} -eq 1 ]
then
bash pretrained_models/download_cyclegan_model.sh horse2zebra
bash ./datasets/download_cyclegan_dataset.sh horse2zebra
fi
python test.py --dataroot datasets/horse2zebra/testA --checkpoints_dir ./checkpoints/ --name horse2zebra_pretrained --no_dropout --model test --dataset_mode single --loadSize 256

echo 'apply a pretrained pix2pix model'
bash pretrained_models/download_pix2pix_model.sh facades_label2photo
bash ./datasets/download_pix2pix_dataset.sh facades
if [ ${DOWNLOAD} -eq 1 ]
then
bash pretrained_models/download_pix2pix_model.sh facades_label2photo
bash ./datasets/download_pix2pix_dataset.sh facades
fi
python test.py --dataroot ./datasets/facades/ --which_direction BtoA --model pix2pix --name facades_label2photo_pretrained --dataset_mode aligned --which_model_netG unet_256 --norm batch


echo 'cyclegan train (1 epoch) and test'
bash ./datasets/download_cyclegan_dataset.sh maps
if [ ${DOWNLOAD} -eq 1 ]
then
bash ./datasets/download_cyclegan_dataset.sh maps
fi
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --no_dropout --niter 1 --niter_decay 0 --max_dataset_size 100 --save_latest_freq 100
python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --phase test --no_dropout

Expand Down

0 comments on commit aa0b8a9

Please sign in to comment.