An example of gan implemented by DL4J #1030
base: master
@@ -0,0 +1,6 @@
An example of a simple GAN implemented with DL4J

       *****    ********        *****************
z ---- * G *----* G(z) * ------ * discriminator * ---- fake
       *****    ********        *               *
x ----------------------------- ***************** ---- real
@@ -0,0 +1,113 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.deeplearning4j</groupId>
    <artifactId>dl4j-gan-examples</artifactId>
    <version>1.0.0-SNAPSHOT</version>

    <!-- Properties section. Change ND4J versions here, if required -->
    <properties>
        <dl4j-master.version>1.0.0-beta7</dl4j-master.version>
        <logback.version>1.2.3</logback.version>
        <java.version>1.8</java.version>
        <maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.12</version>
            <scope>compile</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <!-- Maven compiler plugin: compile for Java 8 -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.5.1</version>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                </configuration>
            </plugin>

            <!--
                Maven shade plugin configuration: this is required so that if you build a single JAR file (an "uber-jar"),
                it will contain all the required native libraries and the backends will work correctly.
                Used, for example, when running the following commands:

                mvn package
                cd target
                java -cp dl4j-gan-examples-1.0.0-SNAPSHOT-bin.jar org.deeplearning4j.ganexamples.SimpleGan
            -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>${maven-shade-plugin.version}</version>
                <configuration>
                    <shadedArtifactAttached>true</shadedArtifactAttached>
                    <shadedClassifierName>bin</shadedClassifierName>
                    <createDependencyReducedPom>true</createDependencyReducedPom>
                    <filters>
                        <filter>
                            <artifact>*:*</artifact>
                            <excludes>
                                <exclude>org/datanucleus/**</exclude>
                                <exclude>META-INF/*.SF</exclude>
                                <exclude>META-INF/*.DSA</exclude>
                                <exclude>META-INF/*.RSA</exclude>
                            </excludes>
                        </filter>
                    </filters>
                </configuration>

                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <transformers>
                                <transformer
                                        implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
                                    <resource>reference.conf</resource>
                                </transformer>
                                <transformer
                                        implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
                                <transformer
                                        implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                </transformer>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
@@ -0,0 +1,75 @@
package org.deeplearning4j.ganexamples;

import org.nd4j.linalg.api.ndarray.INDArray;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.List;

/**
 * @author zdl
 */
public class MNISTVisualizer {
    private double imageScale;
    private List<INDArray> digits;
    private String title;
    private int gridWidth;
    private JFrame frame;

    public MNISTVisualizer(double imageScale, String title) {
        this(imageScale, title, 5);
    }

    public MNISTVisualizer(double imageScale, String title, int gridWidth) {
        this.imageScale = imageScale;
        this.title = title;
        this.gridWidth = gridWidth;
    }

    public void visualize() {
        if (null != frame) {
            frame.dispose();
        }
        frame = new JFrame();
        frame.setTitle(title);
        frame.setSize(800, 600);
        JPanel panel = new JPanel();
        panel.setPreferredSize(new Dimension(800, 600));
        panel.setLayout(new GridLayout(0, gridWidth));
        List<JLabel> list = getComponents();
        for (JLabel image : list) {
            panel.add(image);
        }

        frame.add(panel);
        frame.setVisible(true);
        frame.pack();
    }

    public List<JLabel> getComponents() {
        List<JLabel> images = new ArrayList<JLabel>();
        for (INDArray arr : digits) {
            BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
            for (int i = 0; i < 784; i++) {
                bi.getRaster().setSample(i % 28, i / 28, 0, (int) (255 * arr.getDouble(i)));
            }
Comment on lines +54 to +57: This is a very slow way of drawing the image. You can use …
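One way to speed this up (a sketch of a possible alternative, not necessarily what the reviewer had in mind) is to read the 784 values out of the INDArray once and write the raster in a single call, instead of 784 getDouble()/setSample() round trips:

            // Sketch: batch the pixel transfer (assumes arr holds 784 values in [0, 1])
            double[] values = arr.toDoubleVector();
            int[] pixels = new int[values.length];
            for (int i = 0; i < values.length; i++) {
                pixels[i] = (int) (255 * values[i]);
            }
            bi.getRaster().setPixels(0, 0, 28, 28, pixels);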
            ImageIcon orig = new ImageIcon(bi);
            Image imageScaled = orig.getImage().getScaledInstance((int) (imageScale * 28), (int) (imageScale * 28),
                    Image.SCALE_DEFAULT);
            ImageIcon scaled = new ImageIcon(imageScaled);
            images.add(new JLabel(scaled));
        }
        return images;
    }

    public List<INDArray> getDigits() {
        return digits;
    }

    public void setDigits(List<INDArray> digits) {
        this.digits = digits;
    }

}
@@ -0,0 +1,115 @@
package org.deeplearning4j.ganexamples;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;

/**
 *        *****    ********        *****************
 * z ---- * G *----* G(z) * ------ * discriminator * ---- fake
 *        *****    ********        *               *
 * x ----------------------------- ***************** ---- real
 *
 * @author zdl
 */
public class SimpleGan {

    public static void main(String[] args) throws Exception {

        // Build the discriminator
        MultiLayerConfiguration discriminatorConf = new NeuralNetConfiguration.Builder().seed(12345)
                .weightInit(WeightInit.XAVIER).updater(new RmsProp(0.001))
                .list()
                .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(512).activation(Activation.RELU).build())
                .layer(1, new DenseLayer.Builder().activation(Activation.RELU)
                        .nIn(512).nOut(256).build())
                .layer(2, new DenseLayer.Builder().activation(Activation.RELU)
                        .nIn(256).nOut(128).build())
                .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .activation(Activation.SIGMOID).nIn(128).nOut(1).build()).build();

        // Build the combined network: a trainable generator followed by a frozen copy of the discriminator
        MultiLayerConfiguration ganConf = new NeuralNetConfiguration.Builder().seed(12345)
                .weightInit(WeightInit.XAVIER)
                // generator
                .updater(new RmsProp(0.001)).list()
                .layer(0, new DenseLayer.Builder().nIn(20).nOut(256).activation(Activation.RELU).build())
                .layer(1, new DenseLayer.Builder().activation(Activation.RELU)
                        .nIn(256).nOut(512).build())
                .layer(2, new DenseLayer.Builder().activation(Activation.RELU)
                        .nIn(512).nOut(28 * 28).build())
                // Freeze the discriminator parameters
                .layer(3, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(28 * 28).nOut(512).activation(Activation.RELU).build()))
                .layer(4, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(512).nOut(256).activation(Activation.RELU).build()))
                .layer(5, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU).build()))
                .layer(6, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                        .activation(Activation.SIGMOID).nIn(128).nOut(1).build())).build();

Comment: Why are you using a fully qualified name here instead of importing it?
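As a sketch of what that would look like (not code from this PR): import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop at the top of the file, after which layer 3, for example, shortens to:

                // with FrozenLayerWithBackprop imported, the frozen discriminator layers read more compactly
                .layer(3, new FrozenLayerWithBackprop(
                        new DenseLayer.Builder().nIn(28 * 28).nOut(512).activation(Activation.RELU).build()))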
Comment: For both networks, why don't you use the …
        MultiLayerNetwork discriminatorNetwork = new MultiLayerNetwork(discriminatorConf);
        discriminatorNetwork.init();
        System.out.println(discriminatorNetwork.summary());
        discriminatorNetwork.setListeners(new ScoreIterationListener(1));

        MultiLayerNetwork ganNetwork = new MultiLayerNetwork(ganConf);
        ganNetwork.init();
        ganNetwork.setListeners(new ScoreIterationListener(1));
        System.out.println(ganNetwork.summary());

        DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
Comment: You are using the batch size in a few places; it would be better to have it as a variable, so it isn't just a magic number but carries some semantic meaning.
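A sketch of that suggestion (the variable names here are illustrative, not from the PR):

        final int batchSize = 30;   // mini-batch size shared by the iterator, the labels and the noise
        final int latentDim = 20;   // size of the generator input z

        DataSetIterator train = new MnistDataSetIterator(batchSize, true, 12345);
        INDArray labelD = Nd4j.vstack(Nd4j.ones(batchSize, 1), Nd4j.zeros(batchSize, 1));
        INDArray labelG = Nd4j.ones(batchSize, 1);
        INDArray z = Nd4j.rand(new NormalDistribution(), new long[]{batchSize, latentDim});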
        INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));
        INDArray labelG = Nd4j.ones(30, 1);
        MNISTVisualizer mnistVisualizer = new MNISTVisualizer(1, "Gan");
        for (int i = 1; i <= 100000; i++) {
            if (!train.hasNext()) {
                train.reset();
            }
            INDArray trueImage = train.next().getFeatures();
            INDArray z = Nd4j.rand(new NormalDistribution(), new long[]{30, 20});
            List<INDArray> ganFeedForward = ganNetwork.feedForward(z, false);
            // feedForward index 3 = activations of layer 2, i.e. the generated 28*28 images
            INDArray fakeImage = ganFeedForward.get(3);
            INDArray trainDiscriminatorFeatures = Nd4j.vstack(trueImage, fakeImage);
            // Train the discriminator on real (label 1) and generated (label 0) images
            discriminatorNetwork.fit(trainDiscriminatorFeatures, labelD);
            copyDiscriminatorParam(discriminatorNetwork, ganNetwork);
            // Train the generator through the frozen discriminator copy
            ganNetwork.fit(z, labelG);
            if (i % 1000 == 0) {
                List<INDArray> indArrays = ganNetwork.feedForward(Nd4j.rand(new NormalDistribution(), new long[]{30, 20}), false);
                List<INDArray> list = new ArrayList<>();
                INDArray indArray = indArrays.get(3);
                for (int j = 0; j < indArray.size(0); j++) {
                    list.add(indArray.getRow(j));
                }
                mnistVisualizer.setDigits(list);
                mnistVisualizer.visualize();
            }
        }
    }

    public static void copyDiscriminatorParam(MultiLayerNetwork discriminatorNetwork, MultiLayerNetwork ganNetwork) {
        for (int i = 0; i <= 3; i++) {
            ganNetwork.getLayer(i + 3).setParams(discriminatorNetwork.getLayer(i).params());
        }
    }
}
Comment: Please ensure that examples like this are actually runnable.