Skip to content

Commit

Permalink
1. datasets are now configured automatically based on dataset_mode op…
Browse files Browse the repository at this point in the history
…tion. Please see data/__init__.py

2. Each dataset can now override the default options, although none of the current datasets do so yet.
3. [none] option was explicitly added to --resize_or_crop option. The image sizes are still adjusted to multiples of 4.
4. better visdom error display
5. pix2pix_model now sets more default values
  • Loading branch information
taesungp committed Jun 13, 2018
1 parent 08f4de1 commit 508c014
Show file tree
Hide file tree
Showing 12 changed files with 146 additions and 41 deletions.
62 changes: 41 additions & 21 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,60 @@
import importlib
import torch.utils.data
from data.base_data_loader import BaseDataLoader
from data.base_dataset import BaseDataset

def find_dataset_using_name(dataset_name):
    """Import "data/<dataset_name>_dataset.py" and return its dataset class.

    The imported module must define a subclass of BaseDataset whose class
    name equals ``dataset_name`` with underscores removed, followed by
    ``dataset``; the comparison is case-insensitive (for example,
    "aligned" resolves to a class named AlignedDataset).

    Args:
        dataset_name: value of the dataset_mode option, e.g. "aligned".

    Returns:
        The matching dataset class itself (not an instance).

    Raises:
        ImportError: if the module "data.<dataset_name>_dataset" cannot be
            imported.
        NotImplementedError: if the module contains no matching subclass
            of BaseDataset.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    wanted = target_dataset_name.lower()

    dataset = None
    for name, cls in datasetlib.__dict__.items():
        # Check that the attribute is a class before calling issubclass(),
        # which raises TypeError for non-class objects that happen to
        # share the target name.
        if (name.lower() == wanted
                and isinstance(cls, type)
                and issubclass(cls, BaseDataset)):
            dataset = cls

    if dataset is None:
        # Fail loudly instead of printing and exiting with status 0, which
        # reported success to the calling shell.
        raise NotImplementedError(
            "In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase."
            % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the modify_commandline_options callable of the dataset class.

    The class is looked up from ``dataset_name`` by
    find_dataset_using_name; the returned callable is not invoked here.
    """
    return find_dataset_using_name(dataset_name).modify_commandline_options


def create_dataset(opt):
    """Build and initialize the dataset selected by ``opt.dataset_mode``.

    The dataset class is resolved by name, instantiated without arguments,
    initialized with ``opt``, announced on stdout, and returned.
    """
    dataset_cls = find_dataset_using_name(opt.dataset_mode)
    created = dataset_cls()
    created.initialize(opt)
    print("dataset [%s] was created" % (created.name()))
    return created


def CreateDataLoader(opt):
    """Create a CustomDatasetDataLoader, print its name, and initialize it.

    Returns the initialized loader.
    """
    loader = CustomDatasetDataLoader()
    print(loader.name())
    loader.initialize(opt)
    return loader


def CreateDataset(opt):
    """Instantiate and initialize the dataset named by ``opt.dataset_mode``.

    Supported modes are 'aligned', 'unaligned' and 'single'. The matching
    dataset module is imported only when its mode is selected.

    Raises:
        ValueError: if ``opt.dataset_mode`` is not one of the supported modes.
    """
    mode = opt.dataset_mode

    # Keep the imports inside the branches so that only the selected
    # dataset module is loaded.
    if mode == 'aligned':
        from data.aligned_dataset import AlignedDataset as dataset_cls
    elif mode == 'unaligned':
        from data.unaligned_dataset import UnalignedDataset as dataset_cls
    elif mode == 'single':
        from data.single_dataset import SingleDataset as dataset_cls
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    dataset = dataset_cls()
    # The creation message is printed before initialize() runs.
    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset


## Wrapper class of Dataset class that performs
## multi-threaded data loading
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'

def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataset = create_dataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
Expand Down
4 changes: 4 additions & 0 deletions data/aligned_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@


class AlignedDataset(BaseDataset):
@staticmethod
def modify_commandline_options(parser, is_train):
return parser

def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
Expand Down
59 changes: 57 additions & 2 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@ def __init__(self):
def name(self):
    """Return the name string identifying this dataset class."""
    return 'BaseDataset'

@staticmethod
def modify_commandline_options(parser, is_train):
return parser

def initialize(self, opt):
    """Do nothing; the base implementation ignores ``opt`` and returns None."""

def __len__(self):
    """Return 0; the base implementation reports no samples."""
    return 0


def get_transform(opt):
transform_list = []
Expand All @@ -29,6 +36,11 @@ def get_transform(opt):
transform_list.append(transforms.Lambda(
lambda img: __scale_width(img, opt.loadSize)))
transform_list.append(transforms.RandomCrop(opt.fineSize))
elif opt.resize_or_crop == 'none':
transform_list.append(transforms.Lambda(
lambda img: __adjust(img)))
else:
raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)

if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.RandomHorizontalFlip())
Expand All @@ -38,11 +50,54 @@ def get_transform(opt):
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)

# just modify the width and height to be multiple of 4
def __adjust(img):
    """Round the image's width and height up to the next multiple of 4.

    An image whose sides are both already multiples of 4 is returned
    untouched. Otherwise a size warning is reported and a bicubic-resized
    copy is returned.
    """
    # The size needs to be a multiple of this number, because going
    # through the generator network may change the image size and
    # eventually cause a size mismatch error.
    mult = 4

    ow, oh = img.size
    if ow % mult == 0 and oh % mult == 0:
        return img

    # ((side - 1) // mult + 1) * mult rounds an integer side up to the
    # nearest multiple of mult.
    w, h = (((side - 1) // mult + 1) * mult for side in (ow, oh))

    if (ow, oh) != (w, h):
        __print_size_warning(ow, oh, w, h)

    return img.resize((w, h), Image.BICUBIC)


def __scale_width(img, target_width):
    """Resize ``img`` to ``target_width``, keeping the aspect ratio.

    The resulting height is the proportionally scaled height rounded up to
    a multiple of 4. If the image already has the target width and a
    height that is a multiple of 4, it is returned unchanged.

    Args:
        img: image object exposing ``size`` as (width, height) and a
            ``resize`` method.
        target_width: desired width; must be a multiple of 4.

    Returns:
        The original image, or a bicubic-resized copy.

    Raises:
        AssertionError: if ``target_width`` is not a multiple of 4.
    """
    ow, oh = img.size

    # The size needs to be a multiple of this number, because going
    # through the generator network may change the image size and
    # eventually cause a size mismatch error.
    mult = 4
    assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
    if ow == target_width and oh % mult == 0:
        return img

    w = target_width
    # Height that preserves the aspect ratio, before rounding.
    target_height = int(target_width * oh / ow)
    # Round the height up to the nearest multiple of mult.
    m = (target_height - 1) // mult
    h = (m + 1) * mult

    if target_height != h:
        __print_size_warning(target_width, target_height, w, h)

    return img.resize((w, h), Image.BICUBIC)


def __print_size_warning(ow, oh, w, h):
    """Print the size-adjustment warning, at most once.

    ``(ow, oh)`` is the loaded size and ``(w, h)`` the adjusted size. A
    ``has_printed`` attribute stored on this function records that the
    message was already shown, so later calls stay silent.
    """
    if hasattr(__print_size_warning, 'has_printed'):
        return

    message = ("The image size needs to be a multiple of 4. "
               "The loaded image size was (%d, %d), so it was adjusted to "
               "(%d, %d). This adjustment will be done to all images "
               "whose sizes are not multiples of 4")
    print(message % (ow, oh, w, h))
    __print_size_warning.has_printed = True


4 changes: 4 additions & 0 deletions data/single_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@


class SingleDataset(BaseDataset):
@staticmethod
def modify_commandline_options(parser, is_train):
return parser

def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
Expand Down
4 changes: 4 additions & 0 deletions data/unaligned_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@


class UnalignedDataset(BaseDataset):
@staticmethod
def modify_commandline_options(parser, is_train):
return parser

def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
Expand Down
6 changes: 6 additions & 0 deletions models/pix2pix_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ def name(self):

@staticmethod
def modify_commandline_options(parser, is_train=True):

# changing the default values to match the pix2pix paper
# (https://phillipi.github.io/pix2pix/)
parser.set_defaults(pool_size=0)
parser.set_defaults(no_lsgan=True)
parser.set_defaults(norm='batch')
parser.set_defaults(dataset_mode='aligned')
parser.set_defaults(which_model_netG='unet_256')
if is_train:
Expand Down
1 change: 0 additions & 1 deletion models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ def name(self):
def modify_commandline_options(parser, is_train=True):
assert not is_train, 'TestModel cannot be used in train mode'
parser.set_defaults(dataset_mode='single')
parser.set_defaults(phase='test')

parser.add_argument('--model_suffix', type=str, default='')

Expand Down
9 changes: 6 additions & 3 deletions options/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from util import util
import torch
import models
import data


class BaseOptions():
Expand Down Expand Up @@ -47,23 +48,25 @@ def initialize(self, parser):
return parser

def gather_options(self):

# initialize parser with basic options
if not self.initialized:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = self.initialize(parser)

# get the basic options
opt, unknown = parser.parse_known_args()
opt, _ = parser.parse_known_args()

# modify model-related parser options
model_name = opt.model
model_option_setter = models.get_option_setter(model_name)
parser = model_option_setter(parser, self.isTrain)
opt, _ = parser.parse_known_args() # parse again with the new defaults

# POSSIBLE FEATURE:
# modify dataset-related parser options
dataset_name = opt.dataset_mode
dataset_option_setter = data.get_option_setter(dataset_name)
parser = dataset_option_setter(parser, self.isTrain)

self.parser = parser

Expand Down
4 changes: 4 additions & 0 deletions options/test_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,9 @@ def initialize(self, parser):
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')

# To avoid cropping, the loadSize should be the same as fineSize
parser.set_defaults(loadSize=parser.get_default('fineSize'))

self.isTrain = False
return parser
2 changes: 1 addition & 1 deletion scripts/check_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan


echo 'pix2pix train (1 epoch) and test'
python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --dataset_mode aligned --no_lsgan --norm batch --pool_size 0 --niter 1 --niter_decay 0 --save_latest_freq 400
python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_L1 100 --dataset_mode aligned --no_lsgan --norm batch --pool_size 0 --niter 1 --niter_decay 0 --save_latest_freq 400
python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned --norm batch
4 changes: 2 additions & 2 deletions scripts/test_before_push.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def run_bash_command(command):
run_bash_command('python test.py --model test --dataroot ./datasets/mini --name horse2zebra_pretrained --no_dropout --how_many 1')

# test cyclegan
run_bash_command('python train.py --name temp --dataroot ./datasets/mini --niter 1 --niter_decay 0 --save_latest_freq 10 --display_freq 1')
run_bash_command('python train.py --name temp --dataroot ./datasets/mini --niter 1 --niter_decay 0 --save_latest_freq 10 --print_freq 1 --display_id -1')
run_bash_command('python test.py --name temp --dataroot ./datasets/mini --how_many 1')

# test pix2pix
run_bash_command('python train.py --model pix2pix --name temp --dataroot ./datasets/mini_pix2pix --niter 1 --niter_decay 0 --save_latest_freq 10')
run_bash_command('python train.py --model pix2pix --name temp --dataroot ./datasets/mini_pix2pix --niter 1 --niter_decay 0 --save_latest_freq 10 --display_id -1')
run_bash_command('python test.py --model pix2pix --name temp --dataroot ./datasets/mini_pix2pix --how_many 1 --which_direction BtoA')

28 changes: 17 additions & 11 deletions util/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def __init__(self, opt):
def reset(self):
    """Clear the ``saved`` flag by setting it to False."""
    self.saved = False

def throw_visdom_connection_error(self):
    """Print Visdom connection help and terminate with exit status 1.

    Despite the name, no exception is raised for the caller to handle:
    the process exits after the message is printed.
    """
    help_text = (
        '\n\nCould not connect to Visdom server '
        '(https://github.com/facebookresearch/visdom) '
        'for displaying training progress.\n'
        'You can suppress connection to Visdom using the option --display_id -1. '
        'To install visdom, run \n$ pip install visdom\n'
        ', and start the server by \n$ python -m visdom.server.\n\n'
    )
    print(help_text)
    exit(1)

# |visuals|: dictionary of images to display or save
def display_current_results(self, visuals, epoch, save_result):
if self.display_id > 0: # show images in the browser
Expand Down Expand Up @@ -98,8 +102,7 @@ def display_current_results(self, visuals, epoch, save_result):
self.vis.text(table_css + label_html, win=self.display_id + 2,
opts=dict(title=title + ' labels'))
except ConnectionError:
print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n')
exit(1)
self.throw_visdom_connection_error()

else:
idx = 1
Expand Down Expand Up @@ -136,15 +139,18 @@ def plot_current_losses(self, epoch, counter_ratio, opt, losses):
self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
self.vis.line(
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',
'legend': self.plot_data['legend'],
'xlabel': 'epoch',
'ylabel': 'loss'},
win=self.display_id)
try:
self.vis.line(
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',
'legend': self.plot_data['legend'],
'xlabel': 'epoch',
'ylabel': 'loss'},
win=self.display_id)
except ConnectionError:
self.throw_visdom_connection_error()

# losses: same format as |losses| of plot_current_losses
def print_current_losses(self, epoch, i, losses, t, t_data):
Expand Down

0 comments on commit 508c014

Please sign in to comment.