diff --git a/README.md b/README.md index 486196b..14bf347 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Deep Convolution Generative Adversarial Networks +# Deep Convolutional Generative Adversarial Networks - DCGANs ![Generated images from noise on LFW ds after 300 epochs](images/lfw-300epochs.gif) diff --git a/app.py b/app.py index 5f4d954..13e49e9 100644 --- a/app.py +++ b/app.py @@ -60,8 +60,11 @@ def geneator_handler(path): zvector = torch.load(input_filepath) batchSize = zvector.size()[0] - checkpoint = request.form.get("ckp") or "netG_epoch_69.pth" - Generator = DCGAN(netG=os.path.join(MODEL_PATH, checkpoint), zvector=zvector, batchSize=batchSize) + checkpoint = request.form.get("ckp") or "netG_epoch_99.pth" + # GPU and cuda + # Generator = DCGAN(netG=os.path.join(MODEL_PATH, checkpoint), zvector=zvector, batchSize=batchSize, ngpu=1, cuda=True) + # CPU + Generator = DCGAN(netG=os.path.join(MODEL_PATH, checkpoint), zvector=zvector, batchSize=batchSize, ngpu=0) Generator.build_model() Generator.generate() return send_file(OUTPUT_PATH, mimetype='image/png') diff --git a/dcgan.py b/dcgan.py index ef37ee0..3c2f756 100644 --- a/dcgan.py +++ b/dcgan.py @@ -13,6 +13,7 @@ import torchvision.transforms as transforms import torchvision.utils as vutils from torch.autograd import Variable +from collections import OrderedDict # cpu_fix # Number of colours NC = 3 @@ -109,9 +110,8 @@ def __init__(self, NGF = int(ngf) # Load netG try: - torch.load(netG) + os.path.isfile(netG) self._netG = netG - pass except IOError as e: # Does not exist OR no read permissions print ("Unable to open netG file") @@ -133,7 +133,13 @@ def build_model(self): # Build and load the model self._model = _netG(self._ngpu) self._model.apply(weights_init) - self._model.load_state_dict(torch.load(self._netG)) + # Load Model ckp + if self._ngpu is not None and self._ngpu >= 1: + self._model.load_state_dict(torch.load(self._netG)) + else: + # Load GPU model on CPU + self._model.load_state_dict(torch.load(self._netG, map_location=lambda storage, loc: storage)) + self._model.cpu() # If provided use Zvector else create a random input normalized if self._zvector is not None: self._input = self._zvector diff --git a/generate.py b/generate.py index c931ab5..259c521 100644 --- a/generate.py +++ b/generate.py @@ -9,7 +9,7 @@ parser.add_argument('--outf', default='/output', help='folder to output images') parser.add_argument('--Zvector', help="path to Serialized Z vector") parser.add_argument('--cuda', action='store_true', help='enables cuda') -parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') +parser.add_argument('--ngpu', type=int, default=0, help='number of GPUs to use') opt = parser.parse_args() print(opt)