Merge pull request #40 from slimgroup/glow_splitscales
add split_scales option to glow
grizzuti authored Nov 21, 2021
2 parents 0f8603e + 565b40c commit 1961cec
Showing 4 changed files with 212 additions and 31 deletions.
83 changes: 56 additions & 27 deletions src/networks/invertible_network_glow.jl
@@ -25,6 +25,10 @@ export NetworkGlow, NetworkGlow3D
- `K`: number of flow steps per scale (inner loop)
- `split_scales`: if true, apply a squeeze operation, which halves the spatial dimensions and increases the
channel dimension accordingly, then split the output in half along the channel dimension after each scale.
One half is fed through the next layers, while the remaining channels are saved for the output.
- `k1`, `k2`: kernel size of convolutions in residual block. `k1` is the kernel size of the first and third
operator, `k2` is the kernel size of the second operator.
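
As a rough illustration of the channel bookkeeping described above (a hedged sketch, not part of the package; it assumes 2D inputs, where the checkerboard squeeze multiplies the channel count by 4, matching `channel_factor` in the constructor below):

```julia
# Hypothetical helper for illustration only: channel count seen by the coupling
# layers at each scale when split_scales=true (2D assumption: squeeze gives 4x channels).
function channels_per_scale(n_in, L)
    counts = Int[]
    for i = 1:L
        n_in *= 4                   # squeeze
        push!(counts, n_in)
        i < L && (n_in = n_in ÷ 2)  # split: half the channels continue to the next scale
    end
    return counts
end

channels_per_scale(1, 3)  # returns [4, 8, 16]
```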
@@ -56,84 +60,98 @@ export NetworkGlow, NetworkGlow3D
struct NetworkGlow <: InvertibleNetwork
AN::AbstractArray{ActNorm, 2}
CL::AbstractArray{CouplingLayerGlow, 2}
Z_dims::AbstractArray{Tuple, 1}
Z_dims::Union{AbstractArray{Tuple, 1}, Nothing}
L::Int64
K::Int64
split_scales::Bool
end

@Flux.functor NetworkGlow

# Constructor
function NetworkGlow(n_in, n_hidden, L, K; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2)
function NetworkGlow(n_in, n_hidden, L, K; split_scales=false, k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2)

AN = Array{ActNorm}(undef, L, K) # activation normalization
CL = Array{CouplingLayerGlow}(undef, L, K) # coupling layers w/ 1x1 convolution and residual block
Z_dims = Array{Tuple}(undef, L-1) # save dimensions for inverse/backward pass

if split_scales
Z_dims = Array{Tuple}(undef, L-1) # save dimensions for inverse/backward pass
channel_factor = 4
else
Z_dims = nothing
channel_factor = 1
end

for i=1:L
n_in *= 4 # squeeze
n_in *= channel_factor # squeeze if split_scales is turned on
for j=1:K
AN[i, j] = ActNorm(n_in; logdet=true)
CL[i, j] = CouplingLayerGlow(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=true, ndims=ndims)
end
(i < L) && (n_in = Int64(n_in/2)) # split
(i < L && split_scales) && (n_in = Int64(n_in/2)) # split
end

return NetworkGlow(AN, CL, Z_dims, L, K)
return NetworkGlow(AN, CL, Z_dims, L, K, split_scales)
end

NetworkGlow3D(args; kw...) = NetworkGlow(args...; kw..., ndims=3)
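
A minimal usage sketch for the new keyword (illustrative only: `L`, `K`, and the spatial sizes mirror the tests added below, while `n_in=4` and `n_hidden=32` are placeholder choices):

```julia
using InvertibleNetworks

G = NetworkGlow(4, 32, 2, 2; split_scales=true)  # n_in, n_hidden, L, K
X = rand(Float32, 32, 32, 4, 2)                  # nx, ny, n_in, batchsize

Y, logdet = G.forward(X)  # squeeze/split applied at each scale
X_ = G.inverse(Y)         # recovers X up to floating-point error
```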


# Forward pass and compute logdet
function forward(X::AbstractArray{T, N}, G::NetworkGlow) where {T, N}
Z_save = array_of_array(X, G.L-1)
G.split_scales && (Z_save = array_of_array(X, G.L-1))

logdet = 0
for i=1:G.L
X = squeeze(X; pattern="checkerboard")
(G.split_scales) && (X = squeeze(X; pattern="checkerboard"))
for j=1:G.K
X, logdet1 = G.AN[i, j].forward(X)
X, logdet2 = G.CL[i, j].forward(X)
logdet += (logdet1 + logdet2)
end
if i < G.L # don't split after last iteration
if G.split_scales && i < G.L # don't split after last iteration
X, Z = tensor_split(X)
Z_save[i] = Z
G.Z_dims[i] = size(Z)
end
end
X = cat_states(Z_save, X)
G.split_scales && (X = cat_states(Z_save, X))
return X, logdet
end

# Inverse pass and compute gradients
# Inverse pass
function inverse(X::AbstractArray{T, N}, G::NetworkGlow) where {T, N}
Z_save, X = split_states(X, G.Z_dims)
G.split_scales && ((Z_save, X) = split_states(X, G.Z_dims))
for i=G.L:-1:1
if i < G.L
if G.split_scales && i < G.L
X = tensor_cat(X, Z_save[i])
end
for j=G.K:-1:1
X = G.CL[i, j].inverse(X)
X = G.AN[i, j].inverse(X)
end
X = unsqueeze(X; pattern="checkerboard")
(G.split_scales) && (X = unsqueeze(X; pattern="checkerboard"))
end
return X
end

# Backward pass and compute gradients
function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, G::NetworkGlow; set_grad::Bool=true) where {T, N}
ΔZ_save, ΔX = split_states(ΔX, G.Z_dims)
Z_save, X = split_states(X, G.Z_dims)

# Split data and gradients
if G.split_scales
ΔZ_save, ΔX = split_states(ΔX, G.Z_dims)
Z_save, X = split_states(X, G.Z_dims)
end

if ~set_grad
Δθ = Array{Parameter, 1}(undef, 10*G.L*G.K)
∇logdet = Array{Parameter, 1}(undef, 10*G.L*G.K)
end
blkidx = 10*G.L*G.K
for i=G.L:-1:1
if i < G.L
X = tensor_cat(X, Z_save[i])
if G.split_scales && i < G.L
X = tensor_cat(X, Z_save[i])
ΔX = tensor_cat(ΔX, ΔZ_save[i])
end
for j=G.K:-1:1
@@ -148,8 +166,10 @@ function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, G::NetworkGl
end
blkidx -= 10
end
X = unsqueeze(X; pattern="checkerboard")
ΔX = unsqueeze(ΔX; pattern="checkerboard")
if G.split_scales
X = unsqueeze(X; pattern="checkerboard")
ΔX = unsqueeze(ΔX; pattern="checkerboard")
end
end
set_grad ? (return ΔX, X) : (return ΔX, Δθ, X, ∇logdet)
end
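
The `10*G.L*G.K` block indexing above rests on each flow step exposing 10 parameters, presumably 2 from `ActNorm` and 8 from `CouplingLayerGlow` (1x1 convolution plus residual block). A small sketch for checking this assumption on a constructed network:

```julia
using InvertibleNetworks

# Sketch only: with 2 ActNorm parameters and 8 coupling-layer parameters per flow
# step, the full parameter vector should have length 10*L*K.
G = NetworkGlow(4, 32, 2, 2; split_scales=true)
length(get_params(G))  # expected: 10 * G.L * G.K = 40
```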
@@ -158,14 +178,20 @@ end
## Jacobian-related utils

function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, G::NetworkGlow) where {T, N}
Z_save = array_of_array(ΔX, G.L-1)
ΔZ_save = array_of_array(ΔX, G.L-1)

if G.split_scales
Z_save = array_of_array(ΔX, G.L-1)
ΔZ_save = array_of_array(ΔX, G.L-1)
end
logdet = 0
GNΔθ = Array{Parameter, 1}(undef, 10*G.L*G.K)
blkidx = 0
for i=1:G.L
X = squeeze(X; pattern="checkerboard")
ΔX = squeeze(ΔX; pattern="checkerboard")
if G.split_scales
X = squeeze(X; pattern="checkerboard")
ΔX = squeeze(ΔX; pattern="checkerboard")
end

for j=1:G.K
Δθ_ij = Δθ[blkidx+1:blkidx+10]
ΔX, X, logdet1, GNΔθ1 = G.AN[i, j].jacobian(ΔX, Δθ_ij[1:2], X)
@@ -174,16 +200,19 @@ function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, G::Net
GNΔθ[blkidx+1:blkidx+10] = cat(GNΔθ1,GNΔθ2; dims=1)
blkidx += 10
end
if i < G.L # don't split after last iteration
if G.split_scales && i < G.L # don't split after last iteration
X, Z = tensor_split(X)
ΔX, ΔZ = tensor_split(ΔX)
Z_save[i] = Z
ΔZ_save[i] = ΔZ
G.Z_dims[i] = size(Z)
end
end
X = cat_states(Z_save, X)
ΔX = cat_states(ΔZ_save, ΔX)
if G.split_scales
X = cat_states(Z_save, X)
ΔX = cat_states(ΔZ_save, ΔX)
end

return ΔX, X, logdet, GNΔθ
end

4 changes: 3 additions & 1 deletion test/test_layers/test_coupling_layer_irim.jl
@@ -2,7 +2,9 @@
# Author: Philipp Witte, [email protected]
# Date: January 2020

using InvertibleNetworks, LinearAlgebra, Test
using InvertibleNetworks, LinearAlgebra, Test, Random

Random.seed!(1);

# Input
nx = 28
153 changes: 151 additions & 2 deletions test/test_networks/test_glow.jl
@@ -4,6 +4,8 @@

using InvertibleNetworks, LinearAlgebra, Test, Random

Random.seed!(1);

# Define network
nx = 32
ny = 32
@@ -13,20 +15,21 @@ batchsize = 2
L = 2
K = 2

###################################################################################################
########################################### Test with split_scales = false #########################
# Invertibility

# Network and input
G = NetworkGlow(n_in, n_hidden, L, K)
X = rand(Float32, nx, ny, n_in, batchsize)

Y = G.forward(X)[1]
X_ = G.backward(Y, Y)[2]
X_ = G.inverse(Y)

@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)

###################################################################################################
# Test gradients are set and cleared
G.backward(Y, Y)

P = get_params(G)
gsum = 0
@@ -149,6 +152,152 @@ end

# Adjoint test

set_params!(G, θ)
dY, Y, _, _ = G.jacobian(dX, dθ, X)
dY_ = randn(Float32, size(dY))
dX_, dθ_, _, _ = G.adjointJacobian(dY_, Y)
a = dot(dY, dY_)
b = dot(dX, dX_)+dot(dθ, dθ_)
@test isapprox(a, b; rtol=1f-3)


########################################### Test with split_scales = true #########################
# Invertibility

# Network and input
G = NetworkGlow(n_in, n_hidden, L, K; split_scales=true)
X = rand(Float32, nx, ny, n_in, batchsize)

Y = G.forward(X)[1]
X_ = G.inverse(Y)

@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)

###################################################################################################
# Test gradients are set and cleared
G.backward(Y, Y)

P = get_params(G)
gsum = 0
for p in P
~isnothing(p.grad) && (global gsum += 1)
end
@test isequal(gsum, L*K*10)

clear_grad!(G)
gsum = 0
for p in P
~isnothing(p.grad) && (global gsum += 1)
end
@test isequal(gsum, 0)


###################################################################################################
# Gradient test

function loss(G, X)
Y, logdet = G.forward(X)
f = -log_likelihood(Y) - logdet
ΔY = -∇log_likelihood(Y)
ΔX, X_ = G.backward(ΔY, Y)
return f, ΔX, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad
end

# Gradient test w.r.t. input
G = NetworkGlow(n_in, n_hidden, L, K; split_scales=true)
X = rand(Float32, nx, ny, n_in, batchsize)
X0 = rand(Float32, nx, ny, n_in, batchsize)
dX = X - X0

f0, ΔX = loss(G, X0)[1:2]
h = 0.1f0
maxiter = 4
err1 = zeros(Float32, maxiter)
err2 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
f = loss(G, X0 + h*dX,)[1]
err1[j] = abs(f - f0)
err2[j] = abs(f - f0 - h*dot(dX, ΔX))
print(err1[j], "; ", err2[j], "\n")
global h = h/2f0
end

@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f1)
@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f1)
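
For context, these ratio checks (and the analogous ones for the parameter and Jacobian tests further down) verify the expected Taylor behavior: halving `h` should roughly halve the first-order error and quarter the second-order error. A standalone sketch of the arithmetic, assuming ideal decay:

```julia
# Sketch only: ideal first- and second-order error decay over maxiter halvings of h.
maxiter = 4
err1 = [0.8f0 / 2^(j-1) for j = 1:maxiter]  # ~O(h)
err2 = [0.3f0 / 4^(j-1) for j = 1:maxiter]  # ~O(h^2)
err1[end] / (err1[1] / 2^(maxiter-1))  # ≈ 1
err2[end] / (err2[1] / 4^(maxiter-1))  # ≈ 1
```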

# Gradient test w.r.t. parameters
X = rand(Float32, nx, ny, n_in, batchsize)
G = NetworkGlow(n_in, n_hidden, L, K; split_scales=true)
G0 = NetworkGlow(n_in, n_hidden, L, K; split_scales=true)
Gini = deepcopy(G0)

# Test one parameter from residual block and 1x1 conv
dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data
dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data

f0, ΔX, ΔW, Δv = loss(G0, X)
h = 0.1f0
maxiter = 4
err3 = zeros(Float32, maxiter)
err4 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW
G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv

f = loss(G0, X)[1]
err3[j] = abs(f - f0)
err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv))
print(err3[j], "; ", err4[j], "\n")
global h = h/2f0
end

@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f1)
@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f1)


###################################################################################################
# Jacobian-related tests

# Gradient test

# Initialization
G = NetworkGlow(n_in, n_hidden, L, K; split_scales=true); G.forward(randn(Float32, nx, ny, n_in, batchsize))
θ = deepcopy(get_params(G))
G0 = NetworkGlow(n_in, n_hidden, L, K; split_scales=true); G0.forward(randn(Float32, nx, ny, n_in, batchsize))
θ0 = deepcopy(get_params(G0))
X = randn(Float32, nx, ny, n_in, batchsize)

# Perturbation (normalized)
dθ = θ-θ0; dθ .*= norm.(θ0)./(norm.(dθ).+1f-10)
dX = randn(Float32, nx, ny, n_in, batchsize); dX *= norm(X)/norm(dX)

# Jacobian eval
dY, Y, _, _ = G.jacobian(dX, dθ, X)

# Test
print("\nJacobian test\n")
h = 0.1f0
maxiter = 5
err5 = zeros(Float32, maxiter)
err6 = zeros(Float32, maxiter)
for j=1:maxiter
set_params!(G, θ+h*dθ)
Y_loc, _ = G.forward(X+h*dX)
err5[j] = norm(Y_loc - Y)
err6[j] = norm(Y_loc - Y - h*dY)
print(err5[j], "; ", err6[j], "\n")
global h = h/2f0
end

@test isapprox(err5[end] / (err5[1]/2^(maxiter-1)), 1f0; atol=1f1)
@test isapprox(err6[end] / (err6[1]/4^(maxiter-1)), 1f0; atol=1f1)

# Adjoint test

set_params!(G, θ)
dY, Y, _, _ = G.jacobian(dX, dθ, X)
dY_ = randn(Float32, size(dY))
3 changes: 2 additions & 1 deletion test/test_utils/test_sequential.jl
@@ -1,8 +1,9 @@
# Author: Gabrio Rizzuti, [email protected]
# Date: September 2020

using InvertibleNetworks, LinearAlgebra, Test, Statistics
using InvertibleNetworks, LinearAlgebra, Test, Statistics, Random

Random.seed!(1);

###############################################################################
# Initialization
