diff --git a/models/base_model.py b/models/base_model.py index 0604c41ae5d..0564bf05671 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -1,6 +1,7 @@ import os import torch from collections import OrderedDict +from . import networks class BaseModel(): @@ -26,6 +27,22 @@ def set_input(self, input): def forward(self): pass + # load and print networks; create shedulars + def setup(self, opt): + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + + if not self.isTrain or opt.continue_train: + self.load_networks(opt.which_epoch) + self.print_networks(opt.verbose) + + # make models eval mode during test time + def eval(self): + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + net.eval() + # used in test time, wrapping `forward` in no_grad() so we don't save # intermediate steps for backprop def test(self): @@ -77,7 +94,6 @@ def save_networks(self, which_epoch): else: torch.save(net.cpu().state_dict(), save_path) - def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): key = keys[i] if i + 1 == len(keys): # at the end, pointing to a parameter/buffer @@ -101,7 +117,7 @@ def load_networks(self, which_epoch): # GitHub source), you can remove str() on self.device state_dict = torch.load(save_path, map_location=str(self.device)) # patch InstanceNorm checkpoints prior to 0.4 - for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop + for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) net.load_state_dict(state_dict) diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index b0efafb2e16..333ff18f486 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -58,15 +58,10 @@ def initialize(self, opt): self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] - self.schedulers = [] self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) - for optimizer in self.optimizers: - self.schedulers.append(networks.get_scheduler(optimizer, opt)) - if not self.isTrain or opt.continue_train: - self.load_networks(opt.which_epoch) - self.print_networks(opt.verbose) + self.setup(opt) def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 16d8e5b6046..db496844264 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -37,7 +37,6 @@ def initialize(self, opt): self.criterionL1 = torch.nn.L1Loss() # initialize optimizers - self.schedulers = [] self.optimizers = [] self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) @@ -45,13 +44,8 @@ def initialize(self, opt): lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) - for optimizer in self.optimizers: - self.schedulers.append(networks.get_scheduler(optimizer, opt)) - if not self.isTrain or opt.continue_train: - self.load_networks(opt.which_epoch) - - self.print_networks(opt.verbose) + self.setup(opt) def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' diff --git a/models/test_model.py b/models/test_model.py index 6f445dcf273..329ec541df1 100644 --- a/models/test_model.py +++ b/models/test_model.py @@ -1,4 +1,3 @@ -from torch.autograd import Variable from .base_model import BaseModel from . import networks @@ -23,8 +22,7 @@ def initialize(self, opt): opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids) - self.load_networks(opt.which_epoch) - self.print_networks(opt.verbose) + self.setup(opt) def set_input(self, input): # we need to use single_dataset mode diff --git a/scripts/check_all.sh b/scripts/check_all.sh index 46b983e7674..f13a64766e0 100644 --- a/scripts/check_all.sh +++ b/scripts/check_all.sh @@ -1,17 +1,27 @@ set -ex +DOWNLOAD=${1} echo 'apply a pretrained cyclegan model' -bash pretrained_models/download_cyclegan_model.sh horse2zebra -bash ./datasets/download_cyclegan_dataset.sh horse2zebra +if [ ${DOWNLOAD} -eq 1 ] +then + bash pretrained_models/download_cyclegan_model.sh horse2zebra + bash ./datasets/download_cyclegan_dataset.sh horse2zebra +fi python test.py --dataroot datasets/horse2zebra/testA --checkpoints_dir ./checkpoints/ --name horse2zebra_pretrained --no_dropout --model test --dataset_mode single --loadSize 256 echo 'apply a pretrained pix2pix model' -bash pretrained_models/download_pix2pix_model.sh facades_label2photo -bash ./datasets/download_pix2pix_dataset.sh facades +if [ ${DOWNLOAD} -eq 1 ] +then + bash pretrained_models/download_pix2pix_model.sh facades_label2photo + bash ./datasets/download_pix2pix_dataset.sh facades +fi python test.py --dataroot ./datasets/facades/ --which_direction BtoA --model pix2pix --name facades_label2photo_pretrained --dataset_mode aligned --which_model_netG unet_256 --norm batch echo 'cyclegan train (1 epoch) and test' -bash ./datasets/download_cyclegan_dataset.sh maps +if [ ${DOWNLOAD} -eq 1 ] +then + bash ./datasets/download_cyclegan_dataset.sh maps +fi python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --no_dropout --niter 1 --niter_decay 0 --max_dataset_size 100 --save_latest_freq 100 python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --phase test --no_dropout