From 9ba91fa13cbb1e7bc4069e46469b34abb5ca4869 Mon Sep 17 00:00:00 2001 From: SsnL Date: Tue, 22 May 2018 14:10:11 -0400 Subject: [PATCH 1/3] update to 0.4 --- README.md | 26 +++++++++++++---------- models/base_model.py | 34 +++++++++++++++++++++++------- models/cycle_gan_model.py | 25 +++++++--------------- models/networks.py | 44 +++++++++++++-------------------------- models/pix2pix_model.py | 23 +++++++------------- models/test_model.py | 9 ++++---- requirements.txt | 4 ++++ util/image_pool.py | 5 ++--- util/util.py | 17 +-------------- util/visualizer.py | 2 +- 10 files changed, 81 insertions(+), 108 deletions(-) create mode 100644 requirements.txt diff --git a/README.md b/README.md index 8965c3544c9..ae967a0be50 100644 --- a/README.md +++ b/README.md @@ -19,20 +19,20 @@ This PyTorch implementation produces results comparable or better than our origi -#### [[EdgesCats Demo]](https://affinelayer.com/pixsrv/) [[pix2pix-tensorflow]](https://github.com/affinelayer/pix2pix-tensorflow) -Written by [Christopher Hesse](https://twitter.com/christophrhesse) +#### [[EdgesCats Demo]](https://affinelayer.com/pixsrv/) [[pix2pix-tensorflow]](https://github.com/affinelayer/pix2pix-tensorflow) +Written by [Christopher Hesse](https://twitter.com/christophrhesse) If you use this code for your research, please cite: -Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks -[Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*, [Taesung Park](https://taesung.me/)\*, [Phillip Isola](https://people.eecs.berkeley.edu/~isola/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros) -In ICCV 2017. (* equal contributions) +Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks +[Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*, [Taesung Park](https://taesung.me/)\*, [Phillip Isola](https://people.eecs.berkeley.edu/~isola/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros) +In ICCV 2017. (* equal contributions) -Image-to-Image Translation with Conditional Adversarial Networks -[Phillip Isola](https://people.eecs.berkeley.edu/~isola), [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz), [Tinghui Zhou](https://people.eecs.berkeley.edu/~tinghuiz), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros) +Image-to-Image Translation with Conditional Adversarial Networks +[Phillip Isola](https://people.eecs.berkeley.edu/~isola), [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz), [Tinghui Zhou](https://people.eecs.berkeley.edu/~tinghuiz), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros) In CVPR 2017. ## Course @@ -83,6 +83,10 @@ python setup.py install pip install visdom pip install dominate ``` +- Alternatively, all dependencies can be installed by +```bash +pip install -r requirements.txt +``` - Clone this repo: ```bash git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix @@ -175,7 +179,7 @@ Note that we specified `--which_direction BtoA` as Facades dataset's A to B dire ## Training/test Details - Flags: see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags. -- CPU/GPU (default `--gpu_ids 0`): set`--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batchSize 32`) to benefit from multiple GPUs. +- CPU/GPU (default `--gpu_ids 0`): set`--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batchSize 32`) to benefit from multiple GPUs. - Visualization: during training, the current results can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and a server running by the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom` set `--display_id 0`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`. - Preprocessing: images can be resized and cropped in different ways using `--resize_or_crop` option. The default option `'resize_and_crop'` resizes the image to be of size `(opt.loadSize, opt.loadSize)` and does a random crop of size `(opt.fineSize, opt.fineSize)`. `'crop'` skips the resizing step and only performs random cropping. `'scale_width'` resizes the image to have width `opt.fineSize` while keeping the aspect ratio. `'scale_width_and_crop'` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`. - Fine-tuning/Resume training: to fine-tune a pre-trained model, or resume the previous training, use the `--continue_train` flag. The program will then load the model based on `which_epoch`. By default, the program will initialize the epoch count as 1. Set `--epoch_count ` to specify a different starting epoch count. @@ -244,12 +248,12 @@ If you use this code for your research, please cite our papers. ``` ## Related Projects -[CycleGAN](https://github.com/junyanz/CycleGAN): Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks -[pix2pix](https://github.com/phillipi/pix2pix): Image-to-image translation with conditional adversarial nets +[CycleGAN](https://github.com/junyanz/CycleGAN): Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks +[pix2pix](https://github.com/phillipi/pix2pix): Image-to-image translation with conditional adversarial nets [iGAN](https://github.com/junyanz/iGAN): Interactive Image Generation via Generative Adversarial Networks ## Cat Paper Collection -If you love cats, and love reading cool graphics, vision, and learning papers, please check out the Cat Paper Collection: +If you love cats, and love reading cool graphics, vision, and learning papers, please check out the Cat Paper Collection: [[Github]](https://github.com/junyanz/CatPapers) [[Webpage]](https://people.eecs.berkeley.edu/~junyanz/cat/cat_papers.html) ## Acknowledgments diff --git a/models/base_model.py b/models/base_model.py index 504f0431632..37b91eb0ed0 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -11,7 +11,7 @@ def initialize(self, opt): self.opt = opt self.gpu_ids = opt.gpu_ids self.isTrain = opt.isTrain - self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) if opt.resize_or_crop != 'scale_width': torch.backends.cudnn.benchmark = True @@ -26,9 +26,11 @@ def set_input(self, input): def forward(self): pass - # used in test time, no backprop + # used in test time, wrapping `forward` in no_grad() so we don't save + # intermediate steps for backprop def test(self): - pass + with torch.no_grad(): + self.forward() # get image paths def get_image_paths(self): @@ -57,7 +59,7 @@ def get_current_losses(self): errors_ret = OrderedDict() for name in self.loss_names: if isinstance(name, str): - errors_ret[name] = getattr(self, 'loss_' + name) + errors_ret[name] = getattr(self, 'loss_' + name).item() return errors_ret # save models to the disk @@ -74,6 +76,17 @@ def save_networks(self, which_epoch): else: torch.save(net.cpu().state_dict(), save_path) + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + # load models from the disk def load_networks(self, which_epoch): for name in self.model_names: @@ -81,10 +94,15 @@ def load_networks(self, which_epoch): save_filename = '%s_net_%s.pth' % (which_epoch, name) save_path = os.path.join(self.save_dir, save_filename) net = getattr(self, 'net' + name) - if len(self.gpu_ids) > 0 and torch.cuda.is_available(): - net.module.load_state_dict(torch.load(save_path)) - else: - net.load_state_dict(torch.load(save_path)) + if isinstance(net, torch.nn.DataParallel): + net = net.module + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + state_dict = torch.load(save_path, map_location=str(self.device)) + # patch InstanceNorm checkpoints prior to 0.4 + for key in state_dict: + self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) + net.load_state_dict(state_dict) # print network information def print_networks(self, verbose): diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index 0d6d09ae02a..a0409285f8a 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -1,5 +1,4 @@ import torch -from torch.autograd import Variable import itertools from util.image_pool import ImagePool from .base_model import BaseModel @@ -50,7 +49,7 @@ def initialize(self, opt): self.fake_A_pool = ImagePool(opt.pool_size) self.fake_B_pool = ImagePool(opt.pool_size) # define loss functions - self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device) self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() # initialize optimizers @@ -71,25 +70,19 @@ def initialize(self, opt): def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' - input_A = input['A' if AtoB else 'B'] - input_B = input['B' if AtoB else 'A'] + real_A = input['A' if AtoB else 'B'] + real_B = input['B' if AtoB else 'A'] if len(self.gpu_ids) > 0: - input_A = input_A.cuda(self.gpu_ids[0], async=True) - input_B = input_B.cuda(self.gpu_ids[0], async=True) - self.input_A = input_A - self.input_B = input_B + real_A = real_A.to(self.device) + real_B = real_B.to(self.device) + self.real_A = real_A + self.real_B = real_B self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): - self.real_A = Variable(self.input_A) - self.real_B = Variable(self.input_B) - - def test(self): - self.real_A = Variable(self.input_A, volatile=True) self.fake_B = self.netG_A(self.real_A) self.rec_A = self.netG_B(self.fake_B) - self.real_B = Variable(self.input_B, volatile=True) self.fake_A = self.netG_B(self.real_B) self.rec_B = self.netG_A(self.fake_A) @@ -131,19 +124,15 @@ def backward_G(self): self.loss_idt_B = 0 # GAN loss D_A(G_A(A)) - self.fake_B = self.netG_A(self.real_A) self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True) # GAN loss D_B(G_B(B)) - self.fake_A = self.netG_B(self.real_B) self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True) # Forward cycle loss - self.rec_A = self.netG_B(self.fake_B) self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A # Backward cycle loss - self.rec_B = self.netG_A(self.fake_A) self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # combined loss self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B diff --git a/models/networks.py b/models/networks.py index 0f3fbb2b38c..20167bf5c1d 100644 --- a/models/networks.py +++ b/models/networks.py @@ -2,7 +2,6 @@ import torch.nn as nn from torch.nn import init import functools -from torch.autograd import Variable from torch.optim import lr_scheduler ############################################################################### @@ -42,20 +41,20 @@ def init_func(m): classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': - init.normal(m.weight.data, 0.0, gain) + init.normal_(m.weight.data, 0.0, gain) elif init_type == 'xavier': - init.xavier_normal(m.weight.data, gain=gain) + init.xavier_normal_(m.weight.data, gain=gain) elif init_type == 'kaiming': - init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': - init.orthogonal(m.weight.data, gain=gain) + init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: - init.constant(m.bias.data, 0.0) + init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: - init.normal(m.weight.data, 1.0, gain) - init.constant(m.bias.data, 0.0) + init.normal_(m.weight.data, 1.0, gain) + init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) @@ -64,7 +63,7 @@ def init_func(m): def init_net(net, init_type='normal', gpu_ids=[]): if len(gpu_ids) > 0: assert(torch.cuda.is_available()) - net.cuda(gpu_ids[0]) + net.to(gpu_ids[0]) net = torch.nn.DataParallel(net, gpu_ids) init_weights(net, init_type) return net @@ -114,36 +113,21 @@ def define_D(input_nc, ndf, which_model_netD, # but it abstracts away the need to create the target label tensor # that has the same size as the input class GANLoss(nn.Module): - def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, - tensor=torch.FloatTensor): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): super(GANLoss, self).__init__() - self.real_label = target_real_label - self.fake_label = target_fake_label - self.real_label_var = None - self.fake_label_var = None - self.Tensor = tensor + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) if use_lsgan: self.loss = nn.MSELoss() else: self.loss = nn.BCELoss() def get_target_tensor(self, input, target_is_real): - target_tensor = None if target_is_real: - create_label = ((self.real_label_var is None) or - (self.real_label_var.numel() != input.numel())) - if create_label: - real_tensor = self.Tensor(input.size()).fill_(self.real_label) - self.real_label_var = Variable(real_tensor, requires_grad=False) - target_tensor = self.real_label_var + target_tensor = self.real_label else: - create_label = ((self.fake_label_var is None) or - (self.fake_label_var.numel() != input.numel())) - if create_label: - fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) - self.fake_label_var = Variable(fake_tensor, requires_grad=False) - target_tensor = self.fake_label_var - return target_tensor + target_tensor = self.fake_label + return target_tensor.expand_as(input) def __call__(self, input, target_is_real): target_tensor = self.get_target_tensor(input, target_is_real) diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 4ab014d83e2..ccb76468a83 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -1,5 +1,4 @@ import torch -from torch.autograd import Variable from util.image_pool import ImagePool from .base_model import BaseModel from . import networks @@ -34,7 +33,7 @@ def initialize(self, opt): if self.isTrain: self.fake_AB_pool = ImagePool(opt.pool_size) # define loss functions - self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device) self.criterionL1 = torch.nn.L1Loss() # initialize optimizers @@ -56,25 +55,17 @@ def initialize(self, opt): def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' - input_A = input['A' if AtoB else 'B'] - input_B = input['B' if AtoB else 'A'] + real_A = input['A' if AtoB else 'B'] + real_B = input['B' if AtoB else 'A'] if len(self.gpu_ids) > 0: - input_A = input_A.cuda(self.gpu_ids[0], async=True) - input_B = input_B.cuda(self.gpu_ids[0], async=True) - self.input_A = input_A - self.input_B = input_B + real_A = real_A.to(self.device) + real_B = real_B.to(self.device) + self.real_A = real_A + self.real_B = real_B self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): - self.real_A = Variable(self.input_A) self.fake_B = self.netG(self.real_A) - self.real_B = Variable(self.input_B) - - # no backprop gradients - def test(self): - self.real_A = Variable(self.input_A, volatile=True) - self.fake_B = self.netG(self.real_A) - self.real_B = Variable(self.input_B, volatile=True) def backward_D(self): # Fake diff --git a/models/test_model.py b/models/test_model.py index 6fddd1a5124..ccc3c353be0 100644 --- a/models/test_model.py +++ b/models/test_model.py @@ -28,12 +28,11 @@ def initialize(self, opt): def set_input(self, input): # we need to use single_dataset mode - input_A = input['A'] + real_A = input['A'] if len(self.gpu_ids) > 0: - input_A = input_A.cuda(self.gpu_ids[0], async=True) - self.input_A = input_A + real_A = real_A.to(self.device) + self.real_A = real_A self.image_paths = input['A_paths'] - def test(self): - self.real_A = Variable(self.input_A, volatile=True) + def forward(self): self.fake_B = self.netG(self.real_A) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000000..072d027a2b0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +torch>=0.4.0 +torchvision>=0.2.1 +dominate>=2.3.1 +visdom>=0.1.8.3 diff --git a/util/image_pool.py b/util/image_pool.py index ad8ac261dc1..52413e0f8a4 100644 --- a/util/image_pool.py +++ b/util/image_pool.py @@ -1,6 +1,5 @@ import random import torch -from torch.autograd import Variable class ImagePool(): @@ -23,11 +22,11 @@ def query(self, images): else: p = random.uniform(0, 1) if p > 0.5: - random_id = random.randint(0, self.pool_size - 1) + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive tmp = self.images[random_id].clone() self.images[random_id] = image return_images.append(tmp) else: return_images.append(image) - return_images = Variable(torch.cat(return_images, 0)) + return_images = torch.cat(return_images, 0) return return_images diff --git a/util/util.py b/util/util.py index 23440f8685e..ba7b083ca18 100644 --- a/util/util.py +++ b/util/util.py @@ -3,27 +3,12 @@ import numpy as np from PIL import Image import os -from torch import is_tensor -from torch.autograd import Variable - - -# Converts a Tensor into a float -def tensor2float(input_error): - if is_tensor(input_error): - error = input_error[0] - elif isinstance(input_error, Variable): - error = input_error.data[0] - else: - error = input_error - return error # Converts a Tensor into an image array (numpy) # |imtype|: the desired type of the converted numpy array def tensor2im(input_image, imtype=np.uint8): - if is_tensor(input_image): - image_tensor = input_image - elif isinstance(input_image, Variable): + if isinstance(input_image, torch.Tensor): image_tensor = input_image.data else: return input_image diff --git a/util/visualizer.py b/util/visualizer.py index aeff7dfba34..dcd7d7dd69a 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -104,7 +104,7 @@ def plot_current_losses(self, epoch, counter_ratio, opt, losses): if not hasattr(self, 'plot_data'): self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} self.plot_data['X'].append(epoch + counter_ratio) - self.plot_data['Y'].append([util.tensor2float(losses[k]) for k in self.plot_data['legend']]) + self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) self.vis.line( X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), Y=np.array(self.plot_data['Y']), From e1112e7e66f50cbee322877e46483e9ea32d20d2 Mon Sep 17 00:00:00 2001 From: Tongzhou Wang Date: Tue, 22 May 2018 15:41:29 -0400 Subject: [PATCH 2/3] fixes --- models/base_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/models/base_model.py b/models/base_model.py index 37b91eb0ed0..0604c41ae5d 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -59,7 +59,8 @@ def get_current_losses(self): errors_ret = OrderedDict() for name in self.loss_names: if isinstance(name, str): - errors_ret[name] = getattr(self, 'loss_' + name).item() + # float(...) works for both scalar tensor and float number + errors_ret[name] = float(getattr(self, 'loss_' + name)) return errors_ret # save models to the disk @@ -100,7 +101,7 @@ def load_networks(self, which_epoch): # GitHub source), you can remove str() on self.device state_dict = torch.load(save_path, map_location=str(self.device)) # patch InstanceNorm checkpoints prior to 0.4 - for key in state_dict: + for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) net.load_state_dict(state_dict) From 7abbeaa145581496fb43276400aab2ae4387d463 Mon Sep 17 00:00:00 2001 From: Tongzhou Wang Date: Tue, 22 May 2018 15:47:51 -0400 Subject: [PATCH 3/3] simplify code --- models/cycle_gan_model.py | 9 ++------- models/pix2pix_model.py | 9 ++------- models/test_model.py | 5 +---- 3 files changed, 5 insertions(+), 18 deletions(-) diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index a0409285f8a..b0efafb2e16 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -70,13 +70,8 @@ def initialize(self, opt): def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' - real_A = input['A' if AtoB else 'B'] - real_B = input['B' if AtoB else 'A'] - if len(self.gpu_ids) > 0: - real_A = real_A.to(self.device) - real_B = real_B.to(self.device) - self.real_A = real_A - self.real_B = real_B + self.real_A = input['A' if AtoB else 'B'].to(self.device) + self.real_B = input['B' if AtoB else 'A'].to(self.device) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index ccb76468a83..16d8e5b6046 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -55,13 +55,8 @@ def initialize(self, opt): def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' - real_A = input['A' if AtoB else 'B'] - real_B = input['B' if AtoB else 'A'] - if len(self.gpu_ids) > 0: - real_A = real_A.to(self.device) - real_B = real_B.to(self.device) - self.real_A = real_A - self.real_B = real_B + self.real_A = input['A' if AtoB else 'B'].to(self.device) + self.real_B = input['B' if AtoB else 'A'].to(self.device) self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): diff --git a/models/test_model.py b/models/test_model.py index ccc3c353be0..6f445dcf273 100644 --- a/models/test_model.py +++ b/models/test_model.py @@ -28,10 +28,7 @@ def initialize(self, opt): def set_input(self, input): # we need to use single_dataset mode - real_A = input['A'] - if len(self.gpu_ids) > 0: - real_A = real_A.to(self.device) - self.real_A = real_A + self.real_A = input['A'].to(self.device) self.image_paths = input['A_paths'] def forward(self):