Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

An example of gan implemented by DL4J #1030

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions dl4j-gan-examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
An example of a simple gan implemented with DL4J

***** ******** *****************
z ---- * G *----* G(z) * ------ * discriminator * ---- fake
***** ******** * *
x ----------------------------- ***************** ---- real
113 changes: 113 additions & 0 deletions dl4j-gan-examples/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-gan-examples</artifactId>
<version>1.0.0-SNAPSHOT</version>

<!-- Properties Section. Change ND4J versions here, if required -->
<properties>
<dl4j-master.version>1.0.0-beta7</dl4j-master.version>
<logback.version>1.2.3</logback.version>
<java.version>1.8</java.version>
<maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>


<dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.12</version>
<scope>compile</scope>
</dependency>
</dependencies>

<build>
<plugins>
<!-- Maven compiler plugin: compile for Java 8 -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.5.1</version>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>


<!--
Maven shade plugin configuration: this is required so that if you build a single JAR file (an "uber-jar")
it will contain all the required native libraries, and the backends will work correctly.
Used for example when running the following commants

mvn package
cd target
java -cp deeplearning4j-examples-1.0.0-beta-bin.jar org.deeplearning4j.LenetMnistExample
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please ensure that examples like this are actually runable.

-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>${maven-shade-plugin.version}</version>
<configuration>
<shadedArtifactAttached>true</shadedArtifactAttached>
<shadedClassifierName>bin</shadedClassifierName>
<createDependencyReducedPom>true</createDependencyReducedPom>
<filters>
<filter>
<artifact>*:*</artifact>
<excludes>
<exclude>org/datanucleus/**</exclude>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</filter>
</filters>
</configuration>

<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer
implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
<transformer
implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
<transformer
implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
</transformer>
</transformers>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package org.deeplearning4j.ganexamples;

import org.nd4j.linalg.api.ndarray.INDArray;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.List;

/**
* @author zdl
*/
public class MNISTVisualizer {
private double imageScale;
private List<INDArray> digits;
private String title;
private int gridWidth;
private JFrame frame;

public MNISTVisualizer(double imageScale, String title) {
this(imageScale, title, 5);
}

public MNISTVisualizer(double imageScale, String title, int gridWidth) {
this.imageScale = imageScale;
this.title = title;
this.gridWidth = gridWidth;
}

public void visualize() {
if (null != frame) {
frame.dispose();
}
frame = new JFrame();
frame.setTitle(title);
frame.setSize(800, 600);
JPanel panel = new JPanel();
panel.setPreferredSize(new Dimension(800, 600));
panel.setLayout(new GridLayout(0, gridWidth));
List<JLabel> list = getComponents();
for (JLabel image : list) {
panel.add(image);
}

frame.add(panel);
frame.setVisible(true);
frame.pack();
}

public List<JLabel> getComponents() {
List<JLabel> images = new ArrayList<JLabel>();
for (INDArray arr : digits) {
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
for (int i = 0; i < 784; i++) {
bi.getRaster().setSample(i % 28, i / 28, 0, (int) (255 * arr.getDouble(i)));
}
Comment on lines +54 to +57
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very slow way of drawing the image.

You can use Java2DNativeImageLoader to turn an INDArray into an JavaCV Mat, then use OpenCVFrameConverter.ToMat to convert that into a Frame and Java2DFrameConverter to copy that frame into an already existing BufferedImage.

ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((int) (imageScale * 28), (int) (imageScale * 28),
Image.SCALE_DEFAULT);
ImageIcon scaled = new ImageIcon(imageScaled);
images.add(new JLabel(scaled));
}
return images;
}

public List<INDArray> getDigits() {
return digits;
}

public void setDigits(List<INDArray> digits) {
this.digits = digits;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package org.deeplearning4j.ganexamples;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;

/**
* ***** ******** *****************
* z ---- * G *----* G(z) * ------ * discriminator * ---- fake
* ***** ******** * *
* x ----------------------------- ***************** ---- real
*
* @author zdl
*/
public class SimpleGan {

public static void main(String[] args) throws Exception {

/**
*Build the discriminator
*/
MultiLayerConfiguration discriminatorConf = new NeuralNetConfiguration.Builder().seed(12345)
.weightInit(WeightInit.XAVIER).updater(new RmsProp(0.001))
.list()
.layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(512).activation(Activation.RELU).build())
.layer(1, new DenseLayer.Builder().activation(Activation.RELU)
.nIn(512).nOut(256).build())
.layer(2, new DenseLayer.Builder().activation(Activation.RELU)
.nIn(256).nOut(128).build())
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
.activation(Activation.SIGMOID).nIn(128).nOut(1).build()).build();


MultiLayerConfiguration ganConf = new NeuralNetConfiguration.Builder().seed(12345)
.weightInit(WeightInit.XAVIER)
//generator
.updater(new RmsProp(0.001)).list()
.layer(0, new DenseLayer.Builder().nIn(20).nOut(256).activation(Activation.RELU).build())
.layer(1, new DenseLayer.Builder().activation(Activation.RELU)
.nIn(256).nOut(512).build())
.layer(2, new DenseLayer.Builder().activation(Activation.RELU)
.nIn(512).nOut(28 * 28).build())
//Freeze the discriminator parameter
.layer(3, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(28 * 28).nOut(512).activation(Activation.RELU).build()))
.layer(4, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(512).nOut(256).activation(Activation.RELU).build()))
.layer(5, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU).build()))
.layer(6, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you using a fully quallified name here instead of importing it?

.activation(Activation.SIGMOID).nIn(128).nOut(1).build())).build();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For both networks, why don't you use the setInputType functionality, so you don't have to set nIn on every layer?



MultiLayerNetwork discriminatorNetwork = new MultiLayerNetwork(discriminatorConf);
discriminatorNetwork.init();
System.out.println(discriminatorNetwork.summary());
discriminatorNetwork.setListeners(new ScoreIterationListener(1));

MultiLayerNetwork ganNetwork = new MultiLayerNetwork(ganConf);
ganNetwork.init();
ganNetwork.setListeners(new ScoreIterationListener(1));
System.out.println(ganNetwork.summary());

DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are using the batch size in a few places, it would be better to have it as a variable, so it isn't just a magic number, but delivers some semantic meaning.


INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));
INDArray labelG = Nd4j.ones(30, 1);
MNISTVisualizer mnistVisualizer = new MNISTVisualizer(1, "Gan");
for (int i = 1; i <= 100000; i++) {
if (!train.hasNext()) {
train.reset();
}
INDArray trueImage = train.next().getFeatures();
INDArray z = Nd4j.rand(new NormalDistribution(), new long[]{30, 20});
List<INDArray> ganFeedForward = ganNetwork.feedForward(z, false);
INDArray fakeImage = ganFeedForward.get(3);
INDArray trainDiscriminatorFeatures = Nd4j.vstack(trueImage, fakeImage);
//Training discriminator
discriminatorNetwork.fit(trainDiscriminatorFeatures, labelD);
copyDiscriminatorParam(discriminatorNetwork, ganNetwork);
//Training generator
ganNetwork.fit(z, labelG);
if (i % 1000 == 0) {
List<INDArray> indArrays = ganNetwork.feedForward(Nd4j.rand(new NormalDistribution(), new long[]{30, 20}), false);
List<INDArray> list = new ArrayList<>();
INDArray indArray = indArrays.get(3);
for (int j = 0; j < indArray.size(0); j++) {
list.add(indArray.getRow(j));
}
mnistVisualizer.setDigits(list);
mnistVisualizer.visualize();
}
}
}

public static void copyDiscriminatorParam(MultiLayerNetwork discriminatorNetwork, MultiLayerNetwork ganNetwork) {
for (int i = 0; i <= 3; i++) {
ganNetwork.getLayer(i + 3).setParams(discriminatorNetwork.getLayer(i).params());
}
}
}