Skip to content

Commit

Permalink
Fix gpu to cpu model, fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
ReDeiPirati committed Sep 12, 2017
1 parent 71f102b commit 1703cf4
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Deep Convolution Generative Adversarial Networks
# Deep Convolutional Generative Adversarial Networks - DCGANs

![Generated images from noise on LFW ds after 300 epochs](images/lfw-300epochs.gif)

Expand Down
7 changes: 5 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@ def geneator_handler(path):
zvector = torch.load(input_filepath)
batchSize = zvector.size()[0]

checkpoint = request.form.get("ckp") or "netG_epoch_69.pth"
Generator = DCGAN(netG=os.path.join(MODEL_PATH, checkpoint), zvector=zvector, batchSize=batchSize)
checkpoint = request.form.get("ckp") or "netG_epoch_99.pth"
# GPU and cuda
# Generator = DCGAN(netG=os.path.join(MODEL_PATH, checkpoint), zvector=zvector, batchSize=batchSize, ngpu=1, cuda=True)
# CPU
Generator = DCGAN(netG=os.path.join(MODEL_PATH, checkpoint), zvector=zvector, batchSize=batchSize, ngpu=0)
Generator.build_model()
Generator.generate()
return send_file(OUTPUT_PATH, mimetype='image/png')
Expand Down
12 changes: 9 additions & 3 deletions dcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from collections import OrderedDict # cpu_fix

# Number of colours
NC = 3
Expand Down Expand Up @@ -109,9 +110,8 @@ def __init__(self,
NGF = int(ngf)
# Load netG
try:
torch.load(netG)
os.path.isfile(netG)
self._netG = netG
pass
except IOError as e:
# Does not exist OR no read permissions
print ("Unable to open netG file")
Expand All @@ -133,7 +133,13 @@ def build_model(self):
# Build and load the model
self._model = _netG(self._ngpu)
self._model.apply(weights_init)
self._model.load_state_dict(torch.load(self._netG))
# Load Model ckp
if self._ngpu is not None and self._ngpu >= 1:
self._model.load_state_dict(torch.load(self._netG))
else:
# Load GPU model on CPU
self._model.load_state_dict(torch.load(self._netG, map_location=lambda storage, loc: storage))
self._model.cpu()
# If provided use Zvector else create a random input normalized
if self._zvector is not None:
self._input = self._zvector
Expand Down
2 changes: 1 addition & 1 deletion generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
parser.add_argument('--outf', default='/output', help='folder to output images')
parser.add_argument('--Zvector', help="path to Serialized Z vector")
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--ngpu', type=int, default=0, help='number of GPUs to use')
opt = parser.parse_args()
print(opt)

Expand Down

0 comments on commit 1703cf4

Please sign in to comment.