-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Complete porting of dcgan on FloydHub
- Loading branch information
1 parent
10bc58c
commit 4d2973f
Showing
9 changed files
with
423 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"family_id": "AZRhWDBRyMHUaEtPsqtWpQ", "name": "dcgan"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
|
||
# Directories and files to ignore when uploading code to floyd | ||
|
||
.git | ||
.eggs | ||
eggs | ||
lib | ||
lib64 | ||
parts | ||
sdist | ||
var | ||
*.pyc | ||
*.swp | ||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
|
||
# Created by https://www.gitignore.io/api/python | ||
|
||
### Python ### | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
|
||
# End of https://www.gitignore.io/api/python |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Replicate on FloydHub | ||
|
||
## CIFAR | ||
|
||
nvidia-docker run --rm -it --ipc=host -v $(pwd):/dcgan -v /home/pirate/Downloads/cifar-10:/dcgan/cifar-10 -w /dcgan pytorch/pytorch:latest | ||
|
||
Training: | ||
|
||
- python main.py --dataset cifar10 --dataroot cifar-10 --outf cifar-10/result --cuda --ngpu 1 --niter 2 | ||
|
||
|
||
## LFW | ||
|
||
nvidia-docker run --rm -it --ipc=host -v $(pwd):/dcgan -v /home/pirate/Downloads/lfw:/dcgan/lfw -w /dcgan pytorch/pytorch:latest | ||
|
||
Training: | ||
|
||
- python main.py --dataset lfw --dataroot lfw --outf lfw/result --cuda --ngpu 1 --niter 2 | ||
|
||
Generating | ||
|
||
Random | ||
- python generate.py --netG lfw/result/netG_epoch_0.pth --outf lfw/result | ||
|
||
Provide a Vector | ||
- python generate.py --netG lfw/result/netG_epoch_0.pth --Zvector lfw/result/fixed_noise.pth --outf lfw/result | ||
- python generate.py --netG lfw/result/netG_epoch_25.pth --Zvector lfw/result/fixed_noise.pth --outf lfw/result | ||
- python generate.py --netG lfw/result/netG_epoch_50.pth --Zvector lfw/result/fixed_noise.pth --outf lfw/result | ||
- python generate.py --netG lfw/result/netG_epoch_69.pth --Zvector lfw/result/fixed_noise.pth --outf lfw/result | ||
|
||
## FLOYDHUB Training | ||
|
||
floyd run --gpu --env pytorch --data samit/datasets/lfw/1:lfw "python main.py --dataset lfw --dataroot /lfw --outf /output --cuda --ngpu 1 --niter 1" | ||
|
||
## FLOYDHUB Generating | ||
|
||
floyd run --gpu --env pytorch --data redeipirati/projects/dcgan/12/output:model "python generate.py --netG /model/netG_epoch_69.pth --outf" | ||
|
||
## FLOYDHUB Serving | ||
|
||
floyd run --gpu --mode serve --env pytorch --data redeipirati/projects/dcgan/12/output:model | ||
|
||
- GET req (random zvector, parameter checkpoint) | ||
curl -X GET -o <NAME_&_PATH_DOWNLOADED_IMG> -F "ckp=<MODEL_CHECKPOINT>" <SERVICE_ENDPOINT> | ||
curl -X GET -o prova.png -F "ckp=netG_epoch_69.pth" https://www.floydhub.com/expose/wQURz6s7Q56HbLeSrRGNCL | ||
|
||
- POST req (upload zvector, parameter checkpoint) | ||
curl -X POST -o <NAME_&_PATH_DOWNLOADED_IMG> -F "file=@<ZVECTOR_SERIALIZED_PATH>" <SERVICE_ENDPOINT> | ||
curl -X POST -o prova.png -F "file=@./lfw/result/fixed_noise.pth" https://www.floydhub.com/expose/wQURz6s7Q56HbLeSrRGNCL |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
""" | ||
Flask Serving | ||
This file is a sample flask app that can be used to test your model with an REST API. | ||
This app does the following: | ||
- Look for a Zvector and/or n_samples parameters | ||
- Returns the output file generated at /output | ||
Additional configuration: | ||
- You can also choose the checkpoint file name to use as a request parameter | ||
- Parameter name: checkpoint | ||
- It is loaded from /input | ||
""" | ||
import os | ||
import torch | ||
from flask import Flask, send_file, request | ||
from werkzeug.exceptions import BadRequest | ||
from werkzeug.utils import secure_filename | ||
from dcgan import DCGAN | ||
|
||
ALLOWED_EXTENSIONS = set(['pth']) | ||
|
||
MODEL_PATH = '/model' | ||
print('Loading model from path: %s' % MODEL_PATH) | ||
OUTPUT_PATH = "/output/generated.png" | ||
|
||
app = Flask('DCGAN-Generator') | ||
|
||
# 2 possible parameters - checkpoint, zinput(file.cpth) | ||
# Return an Image | ||
@app.route('/<path:path>', methods=['GET', 'POST']) | ||
def geneator_handler(path): | ||
zvector = None | ||
batchSize = 1 | ||
# Upload a serialized Zvector | ||
if request.method == 'POST': | ||
# DO things | ||
# check if the post request has the file part | ||
if 'file' not in request.files: | ||
return BadRequest("File not present in request") | ||
file = request.files['file'] | ||
if file.filename == '': | ||
return BadRequest("File name is not present in request") | ||
if not allowed_file(file.filename): | ||
return BadRequest("Invalid file type") | ||
filename = secure_filename(file.filename) | ||
input_filepath = os.path.join('/output', filename) | ||
file.save(input_filepath) | ||
# Load a Z vector and Retrieve the N of samples to generate | ||
zvector = torch.load(input_filepath) | ||
batchSize = zvector.size()[0] | ||
|
||
checkpoint = request.form.get("ckp") or "netG_epoch_69.pth" | ||
Generator = DCGAN(netG=os.path.join(MODEL_PATH, checkpoint), zvector=zvector, batchSize=batchSize) | ||
Generator.build_model() | ||
Generator.generate() | ||
return send_file(OUTPUT_PATH, mimetype='image/png') | ||
|
||
|
||
def allowed_file(filename): | ||
return '.' in filename and \ | ||
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS | ||
|
||
if __name__ == '__main__': | ||
app.run(host='0.0.0.0') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
""" Serving DCGAN | ||
""" | ||
# TODO: Error check | ||
from __future__ import print_function | ||
import os | ||
import random | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.parallel | ||
import torch.backends.cudnn as cudnn | ||
import torch.optim as optim | ||
import torch.utils.data | ||
import torchvision.transforms as transforms | ||
import torchvision.utils as vutils | ||
from torch.autograd import Variable | ||
|
||
# Number of colours | ||
NC = 3 | ||
# Latent Vector Size | ||
NZ = 100 | ||
# Number Gen filter | ||
NGF = 64 | ||
|
||
# Custom weights initialization called on netG and netD | ||
def weights_init(m): | ||
classname = m.__class__.__name__ | ||
if classname.find('Conv') != -1: | ||
m.weight.data.normal_(0.0, 0.02) | ||
elif classname.find('BatchNorm') != -1: | ||
m.weight.data.normal_(1.0, 0.02) | ||
m.bias.data.fill_(0) | ||
|
||
|
||
class _netG(nn.Module): | ||
"""Generator model""" | ||
def __init__(self, ngpu): | ||
super(_netG, self).__init__() | ||
self.ngpu = ngpu | ||
self.main = nn.Sequential( | ||
# input is Z, going into a convolution | ||
nn.ConvTranspose2d(NZ, NGF * 8, 4, 1, 0, bias=False), | ||
nn.BatchNorm2d(NGF * 8), | ||
nn.ReLU(True), | ||
# state size. (ngf*8) x 4 x 4 | ||
nn.ConvTranspose2d(NGF * 8, NGF * 4, 4, 2, 1, bias=False), | ||
nn.BatchNorm2d(NGF * 4), | ||
nn.ReLU(True), | ||
# state size. (ngf*4) x 8 x 8 | ||
nn.ConvTranspose2d(NGF * 4, NGF * 2, 4, 2, 1, bias=False), | ||
nn.BatchNorm2d(NGF * 2), | ||
nn.ReLU(True), | ||
# state size. (ngf*2) x 16 x 16 | ||
nn.ConvTranspose2d(NGF * 2, NGF, 4, 2, 1, bias=False), | ||
nn.BatchNorm2d(NGF), | ||
nn.ReLU(True), | ||
# state size. (ngf) x 32 x 32 | ||
nn.ConvTranspose2d(NGF, NC, 4, 2, 1, bias=False), | ||
nn.Tanh() | ||
# state size. (nc) x 64 x 64 | ||
) | ||
|
||
|
||
def forward(self, input): | ||
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: | ||
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) | ||
else: | ||
output = self.main(input) | ||
return output | ||
|
||
|
||
class DCGAN(object): | ||
"""DCGAN - Generative Model class""" | ||
def __init__(self, | ||
netG, | ||
zvector=None, | ||
batchSize=1, | ||
imageSize=64, | ||
nz=100, | ||
ngf=64, | ||
cuda=False, | ||
ngpu=1, | ||
outf="/output"): | ||
""" | ||
DCGAN - netG Builder | ||
Args: | ||
netG: path to netG (to continue training) | ||
zvector: a Tensor of shape (batchsize, nz, 1, 1) | ||
batchSize: int, input batch size, default 64 | ||
imageSize: int, the height / width of the input image to network, | ||
default 64 | ||
nz: int, size of the latent z vector, default 100 | ||
ngf: int, default 64 | ||
cuda: bool, enables cuda, default False | ||
ngpu: int, number of GPUs to use | ||
outf: string, folder to output images, default output | ||
""" | ||
# Path to Gen weight | ||
self._netG = netG | ||
# Latent Z Vector | ||
self._zvector = zvector | ||
# Number of sample to process | ||
self._batchSize = batchSize | ||
# Latent Z vector dim | ||
self._nz = int(nz) | ||
NZ = int(nz) | ||
# Number Gen Filter | ||
self._ngf = int(ngf) | ||
NGF = int(ngf) | ||
# Load netG | ||
try: | ||
torch.load(netG) | ||
self._netG = netG | ||
pass | ||
except IOError as e: | ||
# Does not exist OR no read permissions | ||
print ("Unable to open netG file") | ||
# Use Cuda | ||
self._cuda = cuda | ||
# How many GPU | ||
self._ngpu = int(ngpu) | ||
# Create outf if not exists | ||
try: | ||
os.makedirs(outf) | ||
except OSError: | ||
pass | ||
self._outf = outf | ||
|
||
|
||
# Build the model loading the weights | ||
def build_model(self): | ||
cudnn.benchmark = True | ||
# Build and load the model | ||
self._model = _netG(self._ngpu) | ||
self._model.apply(weights_init) | ||
self._model.load_state_dict(torch.load(self._netG)) | ||
# If provided use Zvector else create a random input normalized | ||
if self._zvector is not None: | ||
self._input = self._zvector | ||
else: | ||
self._input = torch.FloatTensor(self._batchSize, self._nz, 1, 1).normal_(0, 1) | ||
# cuda? | ||
if self._cuda: | ||
self._model.cuda() | ||
self._input = self._input.cuda() | ||
self._input = Variable(self._input) | ||
|
||
|
||
# Generate the image and store in the output folder | ||
def generate(self): | ||
#print (self._input) | ||
fake = self._model(self._input) | ||
vutils.save_image(fake.data, | ||
'%s/generated.png' % (self._outf), | ||
normalize=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
flask |
Oops, something went wrong.