# UNET GAN pix2pix with VGG19 generator and discriminator

## Table of Contents
- Motivation and goals
- Pre-requisites
- Getting started
- Results
- Contributing
- Sources
## Motivation and goals
- Explore the use cases of GAN
- Experiment with cloud computing resources
- Implement a GAN model with TensorFlow
- Use pretrained models as a discriminator and generator
- Create a training pipeline for GAN models
## Pre-requisites
TensorFlow was built from source using the following configuration:

```
python=3.10
tensorflow=2.9.3
cudatoolkit=11.2
cudnn=8.1
```
This configuration has also been verified to work in Docker:

```dockerfile
FROM tensorflow/tensorflow:2.9.3-gpu
RUN pip install hydra-core tqdm scipy matplotlib
```
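For a local (non-Docker) setup, the pinned versions above would translate into an environment file along these lines. This is only a sketch; the actual `environment.yml` in the repository root is authoritative, and the environment name here is made up:

```yaml
# Illustrative sketch only -- see environment.yml in the repository root.
name: vgg19-gan
channels:
  - conda-forge
dependencies:
  - python=3.10
  - cudatoolkit=11.2
  - cudnn=8.1
  - pip
  - pip:
      - tensorflow==2.9.3
      - hydra-core
      - tqdm
      - scipy
      - matplotlib
```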
A conda environment file will be provided in the root directory of this repository. It has only been tested on a Windows machine.
## Getting started

If you would like to use this project, follow these steps:
- Clone this repository:

  ```shell
  git clone https://github.com/wheynelau/VGG19-gan-experiment.git
  ```
- Install the requirements. This requires conda to be installed on your machine; you can install conda from here.

  ```shell
  conda env create -f environment.yml
  ```
- Set up `conf/config.yaml`

  All available options are in the `config.yaml` file.
- Set up the folders and files
If your images are in the format of two images combined side by side, you can use the `preprocess.py` file to split them into two separate images.
Here is an example of the image:
Your directory should look like this:
```
├───data
│   ├───mask (optional)
│   ├───test
│   └───train
```
Running `preprocess.py` will create a new directory based on the `preprocess_path` in the configuration and split each image into two images. It assumes that the images in `mask` and `test`/`train` have the same file names. This is how the directory appears after running `preprocess.py`:
```shell
$ python src/preprocess.py
```

```
├───$preprocess_path
│   ├───test
│   │   ├───image
│   │   └───target
│   └───train
│       ├───image
│       └───target
```
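The splitting step that `preprocess.py` performs can be sketched as follows. The function name, array shapes, and split axis are assumptions for illustration, not the script's actual code:

```python
import numpy as np

# Sketch: a combined pix2pix image stores input and target side by side,
# so splitting along the width axis yields the two halves.
def split_pair(combined: np.ndarray):
    half = combined.shape[1] // 2
    return combined[:, :half], combined[:, half:]

combined = np.zeros((256, 512, 3), dtype=np.uint8)  # dummy 256x512 RGB pair
image, target = split_pair(combined)
print(image.shape, target.shape)  # (256, 256, 3) (256, 256, 3)
```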
To train the model:

```shell
$ python train.py
```
At the end of `train.py`, I've added a statement to save the generator of the GAN model. This is the generator that will be used for inference.
Run the command below to run inference on a folder of images:

Note: there is no exception handling for non-image files, so please input only image files. In addition, all images will be resized to 256x256.

```shell
$ python infer.py
```
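Since there is no exception handling for non-image files, one option is to pre-filter the input folder before running inference. `list_images` below is a hypothetical helper, not part of the repository:

```python
from pathlib import Path
import tempfile

# Hypothetical pre-filter: keep only common image extensions, since
# infer.py does not handle non-image files.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}

def list_images(folder: str) -> list:
    return sorted(p.name for p in Path(folder).iterdir()
                  if p.suffix.lower() in IMAGE_EXTS)

# Demo on a throwaway folder
tmp = tempfile.mkdtemp()
for name in ("a.png", "notes.txt", "b.jpeg"):
    (Path(tmp) / name).touch()
print(list_images(tmp))  # ['a.png', 'b.jpeg']
```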
## Results

Results are in the RESULTS.md file.
- Implemented a GAN model from scratch in tensorflow.
- Used pretrained models as the generator and discriminator.
- Improved the training pipeline by overriding the `compile` and `train_step` methods.
- Wrote custom callbacks for TensorBoard, checkpointing, and optimiser learning-rate scheduling.
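An overridden `train_step` in a GAN typically computes standard adversarial losses for the two networks. The NumPy sketch below illustrates those losses only; it is not the repository's actual loss code, which may differ (pix2pix setups often add an L1 term as well):

```python
import numpy as np

def bce(pred, label):
    """Binary cross-entropy on sigmoid probabilities."""
    eps = 1e-7
    pred = np.clip(pred, eps, 1 - eps)
    return -np.mean(label * np.log(pred) + (1 - label) * np.log(1 - pred))

def discriminator_loss(real_pred, fake_pred):
    # Real patches should score 1, generated patches 0.
    return bce(real_pred, 1.0) + bce(fake_pred, 0.0)

def generator_loss(fake_pred):
    # Non-saturating loss: the generator tries to make D output 1.
    return bce(fake_pred, 1.0)

d_loss = discriminator_loss(np.array([0.9]), np.array([0.1]))
g_loss = generator_loss(np.array([0.1]))
```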
- GANs are hard to train, especially when the generator and discriminator are not balanced.
- Encountered mode collapse, where the generator generated almost the same image for all the images in the dataset.
- Difficult to achieve equilibrium between the generator and discriminator.
- Explore other loss functions
- Introduce noise to the input images
- Use cGAN instead of GAN
- Implement pytorch version of the model
- Step the discriminator less frequently than the generator
- Experiment with whether CycleGAN works better
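Stepping the discriminator less frequently than the generator could be sketched with a simple schedule like the one below. The function and the interval `k` are hypothetical, shown only to illustrate the idea:

```python
# Hypothetical schedule: update the generator every step, but the
# discriminator only on every k-th step, to keep D from overpowering G.
def should_update_discriminator(step: int, k: int = 3) -> bool:
    return step % k == 0

updates = [step for step in range(9) if should_update_discriminator(step)]
print(updates)  # [0, 3, 6]
```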
## Contributing

Feedback and contributions are welcome. As this is a learning project, I may not be active in maintaining this repository or accepting pull requests.
## Sources

- APDrawingGAN (the datasets were found in this repository)