move repo from private git to public
philippwitte committed Feb 7, 2020
1 parent 7c466fd commit 4f46b57
Showing 35 changed files with 4,020 additions and 1 deletion.
400 changes: 400 additions & 0 deletions Manifest.toml

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions Project.toml
@@ -0,0 +1,13 @@
name = "InvertibleNetworks"
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
authors = ["Philipp Witte <[email protected]>"]
version = "0.1.0"

[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
65 changes: 64 additions & 1 deletion README.md
@@ -1,2 +1,65 @@
# InvertibleNetworks.jl
A Julia framework for invertible neural networks

Building blocks for invertible neural networks in the Julia programming language.

## Installation

```
] dev https://github.gatech.edu/pwitte3/InvertibleNetworks
```
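
Here `]` switches the Julia REPL into package mode. Once installed, the package is loaded with `using InvertibleNetworks`, as in the examples below.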

## Building blocks

- 1x1 Convolutions using Householder transformations ([example](https://github.gatech.edu/pwitte3/InvertibleNetworks/blob/master/examples/convolution_1x1.jl))

- Residual block for pixel shuffling (Putzky and Welling, 2019) ([example](https://github.gatech.edu/pwitte3/InvertibleNetworks/blob/master/examples/residual_block.jl))

- Invertible coupling layer from Putzky and Welling (2019) ([example](https://github.gatech.edu/pwitte3/InvertibleNetworks/blob/master/examples/invertible_layer.jl))

- Invertible coupling layer from Dinh et al. (2017) ([example](https://github.gatech.edu/pwitte3/InvertibleNetworks/blob/master/examples/coupling_layer.jl))

- Activation normalization (Kingma and Dhariwal, 2018) ([example](https://github.gatech.edu/pwitte3/InvertibleNetworks/blob/master/examples/activation_normalization.jl))

- Various activation functions (Sigmoid, ReLU, leaky ReLU, GaLU)

- Dimensionality manipulation: squeeze/unsqueeze (column, patch, checkerboard), split/cat

- Squeeze/unsqueeze using the wavelet transform
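
As a minimal sketch of how these building blocks are used (closely following the bundled `examples/activation_normalization.jl`), each layer is driven through explicit `forward` and `backward` (or `inverse`) calls:

```
using InvertibleNetworks

# 4D input tensor: nx x ny x channels x batchsize
X = glorot_uniform(64, 64, 10, 4)

# Activation normalization layer with log-determinant
AN = ActNorm(10; logdet=true)

# Forward pass: output and log-determinant of the Jacobian
Y, logdet = AN.forward(X)

# Backward pass: takes a residual and the output, returns the gradient
# w.r.t. the input and the recomputed input (no stored activations needed)
ΔX, X_ = AN.backward(Y - X, Y)
```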


## Applications

- Invertible recurrent inference machines (Putzky and Welling, 2019) ([generic example](https://github.gatech.edu/pwitte3/InvertibleNetworks/blob/master/examples/loop_unrolling.jl), [seismic example](https://github.gatech.edu/pwitte3/InvertibleNetworks/blob/master/examples/i-rim_seismic.jl))

- Generative models with maximum likelihood via the change of variable formula ([example](https://github.gatech.edu/pwitte3/InvertibleNetworks/blob/master/examples/generative_model_change_of_variable.jl))

- Glow: Generative flow with invertible 1x1 convolutions (Kingma and Dhariwal, 2018) ([generic example](https://github.gatech.edu/pwitte3/InvertibleNetworks/blob/master/examples/glow_likelihood_logdet.jl), [source](https://github.gatech.edu/pwitte3/InvertibleNetworks/blob/master/src/invertible_network_glow.jl))
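
The generative examples minimize a negative log-likelihood under a standard normal latent distribution. A rough sketch of the objective, mirroring the bundled `examples/glow_likelihood_logdet.jl`:

```
using InvertibleNetworks, LinearAlgebra

# Glow network and random input (dimensions as in examples/glow_likelihood_logdet.jl)
batchsize = 10
G = NetworkGlow(64, 64, 3, batchsize, 32, 2, 2)
X = rand(Float32, 64, 64, 3, batchsize)

# Change-of-variable objective: Gaussian misfit in the latent space
# minus the log-determinant of the network Jacobian
Z, logdet = G.forward(X)
f = 0.5f0/batchsize*norm(Z)^2 - logdet
```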

## To Do

- GPU support


## Acknowledgments

This package uses functions from [NNlib.jl](https://github.com/FluxML/NNlib.jl), [Flux.jl](https://github.com/FluxML/Flux.jl), and [Wavelets.jl](https://github.com/JuliaDSP/Wavelets.jl).


## References

- Yann Dauphin, Angela Fan, Michael Auli and David Grangier, "Language modeling with gated convolutional networks", Proceedings of the 34th International Conference on Machine Learning, 2017. https://arxiv.org/pdf/1612.08083.pdf

- Laurent Dinh, Jascha Sohl-Dickstein and Samy Bengio, "Density estimation using Real NVP", International Conference on Learning Representations, 2017, https://arxiv.org/abs/1605.08803

- Diederik P. Kingma and Prafulla Dhariwal, "Glow: Generative Flow with Invertible 1x1 Convolutions", Conference on Neural Information Processing Systems, 2018. https://arxiv.org/abs/1807.03039

- Patrick Putzky and Max Welling, "Invert to learn to invert", Advances in Neural Information Processing Systems, 2019. https://arxiv.org/pdf/1911.10914.pdf


## Author

- Philipp Witte, Georgia Institute of Technology

- Contact at [email protected]


28 changes: 28 additions & 0 deletions examples/activation_normalization.jl
@@ -0,0 +1,28 @@
# Activation normalization from Kingma & Dhariwal (2018)
# Author: Philipp Witte, [email protected]
# Date: January 2020

using InvertibleNetworks, LinearAlgebra, Test

# Input
nx = 64
ny = 64
k = 10
batchsize = 4

# Input image: nx x ny x k x batchsize
X = randn(Float32, nx, ny, k, batchsize)
Y = randn(Float32, nx, ny, k, batchsize)

# Activation normalization
AN = ActNorm(k; logdet=true)

# Forward pass and data residual
Y_, logdet = AN.forward(X)
ΔY = Y_ - Y

# Backpropagation
ΔX, X_ = AN.backward(ΔY, Y_)

# Test invertibility
@test isapprox(norm(X - X_)/norm(X), 0f0, atol=1f-6)
37 changes: 37 additions & 0 deletions examples/convolution_1x1.jl
@@ -0,0 +1,37 @@
# Example for using the 1x1 convolution operator to permute an image along the channel dimension.
# Author: Philipp Witte, [email protected]
# Date: January 2020

using InvertibleNetworks, LinearAlgebra, Test

# Dimensions
nx = 64 # no. of pixels in x dimension
ny = 64 # no. of pixels in y dimension
k = 10 # no. of channels
batchsize = 4

# Input image: nx x ny x k x batchsize
X = glorot_uniform(nx, ny, k, batchsize)

# 1x1 convolution operators
C = Conv1x1(k)
C0 = Conv1x1(k)

# Generate "true/observed" data with the same dimension as X
Y = C.forward(X)

# Predicted data
Y0 = C0.forward(X)
@test isnothing(C0.v1.grad) # after forward pass, gradients are not set

# Data residual
ΔY = Y0 - Y

# Backward pass: Pass ΔY to compute ΔX, the gradient w.r.t. the input X.
# Also pass Y0 to recompute the forward state X using the inverse mapping
# and use it to compute the derivative w.r.t. the coefficients of the
# Householder matrix.
ΔX, X_ = C0.inverse((ΔY, Y0)) # returns derivative w.r.t input and the recomputed input itself
@test ~isnothing(C0.v1.grad) # after inverse pass, gradients are set
@test isapprox(norm(X - X_)/norm(X), 0f0, atol=1f-6) # X and X_ should be the same

34 changes: 34 additions & 0 deletions examples/coupling_layer.jl
@@ -0,0 +1,34 @@
# Invertible CNN layer from Dinh et al. (2017) and Kingma & Dhariwal (2018)
# Author: Philipp Witte, [email protected]
# Date: January 2020

using InvertibleNetworks, LinearAlgebra, Test

# Input
nx = 64
ny = 64
k = 20
n_in = 10
n_hidden = 20
batchsize = 2
k1 = 4
k2 = 3

# Input image
X = glorot_uniform(nx, ny, k, batchsize)
X0 = glorot_uniform(nx, ny, k, batchsize)

# 1x1 convolution and residual blocks
C = Conv1x1(k)
RB = ResidualBlock(nx, ny, n_in, n_hidden, batchsize; k1=k1, k2=k2, fan=true)

# Invertible splitting layer
L = CouplingLayer(C, RB; logdet=true) # compute logdet

# Forward + backward
Y = L.forward(X)[1]
Y0, logdet = L.forward(X0)
ΔY = Y0 - Y
ΔX, X0_ = L.backward(ΔY, Y0)

@test isapprox(norm(X0_ - X0)/norm(X0), 0f0, atol=1f-2)
103 changes: 103 additions & 0 deletions examples/generative_model_change_of_variable.jl
@@ -0,0 +1,103 @@
# Generative model using the change of variables formula
# Author: Philipp Witte, [email protected]
# Date: January 2020

using LinearAlgebra, InvertibleNetworks, PyPlot, Flux, Random
import Flux.Optimise.update!


# Target distribution
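# sample_banana draws 2D standard-normal samples and warps them into a
# banana-shaped density; the result is stored as a 1 x 1 x 2 x batchsize
# tensor to match the 4D input layout used by the network layers.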
function sample_banana(batchsize; c=[1f0, 4f0])
    x = randn(Float32, 2, batchsize)
    y = zeros(Float32, 1, 1, 2, batchsize)
    y[1,1,1,:] = x[1,:] ./ c[1]
    y[1,1,2,:] = x[2,:].*c[1] + c[1].*c[2].*(x[1,:].^2 .+ c[1]^2)
    return y
end

####################################################################################################

# Define network
nx = 1
ny = 1
n_in = 2
n_hidden = 128
batchsize = 100
depth = 10
AN = Array{ActNorm}(undef, depth)
L = Array{CouplingLayer}(undef, depth)
Params = Array{Parameter}(undef, 0)

# Create layers
for j=1:depth
    AN[j] = ActNorm(n_in; logdet=true)
    L[j] = CouplingLayer(nx, ny, n_in, n_hidden, batchsize; k1=1, k2=1, p1=0, p2=0, logdet=true)

    # Collect parameters
    global Params = cat(Params, get_params(AN[j]); dims=1)
    global Params = cat(Params, get_params(L[j]); dims=1)
end

# Forward pass
function forward(X)
    logdet = 0f0
    for j=1:depth
        X_, logdet1 = AN[j].forward(X)
        X, logdet2 = L[j].forward(X_)
        logdet += (logdet1 + logdet2)
    end
    return X, logdet
end

# Backward pass
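# The layers are traversed in reverse order; each backward call returns the
# gradient w.r.t. its input and recomputes that input from the output, so no
# intermediate activations need to be stored.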
function backward(ΔX, X)
    for j=depth:-1:1
        ΔX_, X_ = L[j].backward(ΔX, X)
        ΔX, X = AN[j].backward(ΔX_, X_)
    end
    return ΔX, X
end

####################################################################################################

# Loss
function loss(X)
    Y_, logdet = forward(X)
    # Negative log-likelihood for a standard normal latent variable (up to a
    # constant), minus the log-determinant of the Jacobian; the gradient of the
    # misfit term w.r.t. Y_ is Y_/batchsize, which seeds the backward pass.
    f = .5f0/batchsize*norm(Y_)^2 - logdet
    ΔX = backward(1f0/batchsize*Y_, Y_)[1]
    return f, ΔX
end

ntrain = 500
X = sample_banana(ntrain)
maxiter = 5000
opt = Flux.ADAM(1f-5)
fval = zeros(Float32, maxiter)
for j=1:maxiter

    # Evaluate objective and gradients
    #idx = randperm(ntrain)[1:batchsize]
    fval[j] = loss(X)[1]
    mod(j, 1) == 0 && (print("Iteration: ", j, "; f = ", fval[j], "\n"))

    # Update params
    for p in Params
        update!(opt, p.data, p.grad)
    end
    clear_grad!(Params)
end

####################################################################################################

# Testing
test_size = 1000
#X = sample_banana(test_size)
Y_ = forward(X)[1]
Y = randn(Float32, 1, 1, 2, test_size)
X_ = backward(Y, Y)[2]

figure(figsize=[8,8])
subplot(2,2,1); plot(X[1, 1, 1, :], X[1, 1, 2, :], "."); title(L"Data space: $x \sim \hat{p}_X$")
subplot(2,2,2); plot(Y_[1, 1, 1, :], Y_[1, 1, 2, :], "g."); title(L"Latent space: $z = f(x)$")
subplot(2,2,3); plot(X_[1, 1, 1, :], X_[1, 1, 2, :], "g."); title(L"Data space: $x = f^{-1}(z)$")
subplot(2,2,4); plot(Y[1, 1, 1, :], Y[1, 1, 2, :], "."); title(L"Latent space: $z \sim \hat{p}_Z$")
39 changes: 39 additions & 0 deletions examples/glow_likelihood_logdet.jl
@@ -0,0 +1,39 @@
# Generative model w/ Glow architecture from Kingma & Dhariwal (2018)
# Author: Philipp Witte, [email protected]
# Date: January 2020

using InvertibleNetworks, LinearAlgebra, Flux
import Flux.Optimise.update!

# Define network
nx = 64 # must be multiple of 2
ny = 64
n_in = 3
n_hidden = 32
batchsize = 10
L = 2 # number of scales
K = 2 # number of flow steps per scale

# Input
X = rand(Float32, nx, ny, n_in, batchsize)

# Glow network
G = NetworkGlow(nx, ny, n_in, batchsize, n_hidden, L, K)

# Objective function
function loss(X)
    Y, logdet = G.forward(X)
    f = .5f0/batchsize*norm(Y)^2 - logdet
    # The backward pass propagates the residual Y/batchsize through the inverse
    # network and stores parameter gradients in p.grad for the update step below.
    ΔX, X_ = G.backward(1f0/batchsize*Y, Y)
    return f
end

# Evaluate loss
f = loss(X)

# Update weights
opt = Flux.ADAM()
Params = get_params(G)
for p in Params
    update!(opt, p.data, p.grad)
end