Commit

Merge pull request #66 from slimgroup/cond_3d
conditional glow 3d with test and example
rafaelorozco authored Mar 29, 2023
2 parents 1adcc49 + 0b5e476 commit 5c092f9
Showing 5 changed files with 196 additions and 20 deletions.
69 changes: 69 additions & 0 deletions examples/networks/network_conditional_glow.jl
@@ -0,0 +1,69 @@
# Generative model w/ Glow architecture from Kingma & Dhariwal (2018)
# Network layers are made conditional with CIIN type layers
# Author: Rafael Orozco, [email protected]
# Date: March 2023

using InvertibleNetworks, LinearAlgebra, Flux

device = InvertibleNetworks.CUDA.functional() ? gpu : cpu

nx = 32 # must be multiple of 2^L where L is the multiscale level of the network
ny = 32 # must be multiple of 2^L where L is the multiscale level of the network
n_in = 4
n_cond = 4
n_hidden = 32
batchsize = 5
L = 2 # number of scales
K = 2 # number of flow steps per scale

# Input
X = rand(Float32, nx, ny, n_in, batchsize) |> device;

# Condition
Y = rand(Float32, nx, ny, n_cond, batchsize) |> device;

# Glow network
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K) |> device

# Objective function
function loss(G, X, Y)
ZX, ZY, logdet = G.forward(X, Y)
f = .5f0/batchsize*norm(ZX)^2 - logdet
G.backward(1f0./batchsize*ZX, ZX, ZY)
return f
end

# Evaluate loss
f = loss(G, X, Y)

# Update weights
opt = Flux.ADAM()
Params = get_params(G)
for p in Params
Flux.update!(opt, p.data, p.grad)
end
clear_grad!(G)

################ 3D example: for 3 spatial dimensions, set ndims=3 on the network
################ or use NetworkConditionalGlow3D
nz = 32

# 3D Input
X_3d = rand(Float32, nx, ny, nz, n_in, batchsize) |> device;

# 3D Condition
Y_3d = rand(Float32, nx, ny, nz, n_cond, batchsize) |> device;

# 3D Glow network
G_3d = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; ndims=3) |> device

# Evaluate loss
f = loss(G_3d, X_3d, Y_3d)

# Update weights
opt = Flux.ADAM()
Params = get_params(G_3d)
for p in Params
Flux.update!(opt, p.data, p.grad)
end
clear_grad!(G_3d)
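
# Usage sketch (illustrative, not part of the diff): after training, conditional samples can be
# drawn by inverting the network. This assumes G.inverse accepts a latent with the same shape as
# ZX together with the transformed condition ZY, as exercised in the tests below.
ZX, ZY, _ = G.forward(X, Y)                  # latent shape and squeezed condition
Z = randn(Float32, size(ZX)...) |> device    # standard-normal latent
X_sample = G.inverse(Z, ZY)                  # sample consistent with condition Y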
6 changes: 3 additions & 3 deletions src/conditional_layers/conditional_layer_glow.jl
@@ -74,10 +74,10 @@ function ConditionalLayerGlow(C::Conv1x1, RB::ResidualBlock; logdet=false, activ
end

# Constructor from input dimensions
function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, activation::ActivationFunction=SigmoidLayer(), rb_activation::ActivationFunction=RELUlayer(), ndims=2)
function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64;freeze_conv=false, k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, activation::ActivationFunction=SigmoidLayer(), rb_activation::ActivationFunction=RELUlayer(), ndims=2)

# 1x1 Convolution and residual block for invertible layers
C = Conv1x1(n_in)
C = Conv1x1(n_in; freeze=freeze_conv)
RB = ResidualBlock(Int(n_in/2)+n_cond, n_hidden; n_out=n_in, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)

return ConditionalLayerGlow(C, RB, logdet, activation)
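
# Note (sketch, not in the diff): the new freeze_conv keyword is forwarded to Conv1x1 as
# freeze=freeze_conv, keeping the 1x1 convolution fixed at its initialization. Illustrative
# construction with assumed sizes (n_in=4, n_cond=4, n_hidden=32):
CL_frozen = ConditionalLayerGlow(4, 4, 32; freeze_conv=true, logdet=true, ndims=2)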
@@ -142,7 +142,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA
end

# Backpropagate RB
ΔX2_ΔC = L.RB.backward(cat(L.activation.backward(ΔS, S), ΔT; dims=3), (tensor_cat(X2, C)))
ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), (tensor_cat(X2, C)))
ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=Int(size(ΔY)[N-1]/2))
ΔX2 += ΔY2

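# Note (sketch, assumed semantics): tensor_cat concatenates along the channel axis (dim N-1),
# so the same backward code handles 4-D (2D spatial) and 5-D (3D spatial) tensors, whereas the
# previous cat(...; dims=3) hard-coded the 2D layout.
A2 = randn(Float32, 8, 8, 2, 1);    B2 = randn(Float32, 8, 8, 3, 1)     # channels at dim 3
A3 = randn(Float32, 8, 8, 8, 2, 1); B3 = randn(Float32, 8, 8, 8, 3, 1)  # channels at dim 4
size(tensor_cat(A2, B2))  # expected (8, 8, 5, 1)
size(tensor_cat(A3, B3))  # expected (8, 8, 8, 5, 1)
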
9 changes: 5 additions & 4 deletions src/networks/invertible_network_conditional_glow.jl
@@ -73,14 +73,14 @@ end
@Flux.functor NetworkConditionalGlow

# Constructor
function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=false, rb_activation::ActivationFunction=ReLUlayer(), k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;freeze_conv=false, split_scales=false, rb_activation::ActivationFunction=ReLUlayer(), k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
AN = Array{ActNorm}(undef, L, K) # activation normalization
AN_C = ActNorm(n_cond; logdet=false) # activation normalization for condition
CL = Array{ConditionalLayerGlow}(undef, L, K) # coupling layers w/ 1x1 convolution and residual block

if split_scales
Z_dims = fill!(Array{Array}(undef, L-1), [1,1]) #fill in with dummy values so that |> gpu accepts it # save dimensions for inverse/backward pass
channel_factor = 4
channel_factor = 2^(ndims)
else
Z_dims = nothing
channel_factor = 1
@@ -91,7 +91,7 @@ function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=false
n_cond *= channel_factor # squeeze if split_scales is turned on
for j=1:K
AN[i, j] = ActNorm(n_in; logdet=true)
CL[i, j] = ConditionalLayerGlow(n_in, n_cond, n_hidden; rb_activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=true, activation=activation, ndims=ndims)
CL[i, j] = ConditionalLayerGlow(n_in, n_cond, n_hidden;freeze_conv=freeze_conv, rb_activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=true, activation=activation, ndims=ndims)
end
(i < L && split_scales) && (n_in = Int64(n_in/2)) # split
end
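
# Note (sketch, not in the diff): with split_scales=true the squeeze now multiplies channels by
# 2^ndims (4 in 2D, 8 in 3D), matching the halving of every spatial axis. Illustrative 3D shapes,
# with spatial sizes divisible by 2^L:
G3 = NetworkConditionalGlow(2, 2, 4, 2, 2; split_scales=true, ndims=3)
X3 = rand(Float32, 16, 16, 16, 2, 1)
C3 = rand(Float32, 16, 16, 16, 2, 1)
ZX3, ZC3, logdet3 = G3.forward(X3, C3)   # each squeeze halves nx, ny, nz and multiplies channels by 8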
@@ -169,10 +169,11 @@ function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractA
end

if G.split_scales
ΔC_total = G.squeezer.inverse(ΔC_total)
C = G.squeezer.inverse(C)
ΔC_total = G.squeezer.inverse(ΔC_total)
X = G.squeezer.inverse(X)
ΔX = G.squeezer.inverse(ΔX)

end
end

2 changes: 1 addition & 1 deletion test/test_layers/test_coupling_layer_hint.jl
@@ -178,4 +178,4 @@ dY_ = randn(Float32, size(dY))
logdet ? ((dX_, dθ_, _, _) = HL.adjointJacobian(dY_, Y)) : ((dX_, dθ_, _) = HL.adjointJacobian(dY_, Y))
a = dot(dY, dY_)
b = dot(dX, dX_)+dot(dθ, dθ_)
@test isapprox(a, b; rtol=1f-3)
@test isapprox(a, b; rtol=1f-3)
130 changes: 118 additions & 12 deletions test/test_networks/test_conditional_glow_network.jl
@@ -10,20 +10,126 @@ Random.seed!(2);
# Define network
nx = 32
ny = 32
nz = 32
n_in = 2
n_cond = 2
n_hidden = 4
batchsize = 2
L = 2
K = 2
split_scales = true
N = (nx,ny)
########################################### Test with split_scales = true N = (nx,ny) #########################
# Invertibility

# Network and input
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
X = rand(Float32, N..., n_in, batchsize)
Cond = rand(Float32, N..., n_cond, batchsize)

Y, Cond = G.forward(X,Cond)
X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes

@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)

###################################################################################################
# Test gradients are set and cleared
G.backward(Y, Y, Cond)

P = get_params(G)
gsum = 0
for p in P
~isnothing(p.grad) && (global gsum += 1)
end
@test isequal(gsum, L*K*10+2)
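# Presumed breakdown (not stated in the diff): each of the L*K flow steps carries 10 parameters,
# namely 2 from ActNorm (s, b), 3 from Conv1x1 (v1, v2, v3) and 5 from the residual block
# (W1, W2, W3, b1, b2), plus 2 from the condition's ActNorm AN_C, giving L*K*10 + 2.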

clear_grad!(G)
gsum = 0
for p in P
~isnothing(p.grad) && (global gsum += 1)
end
@test isequal(gsum, 0)


###################################################################################################
# Gradient test

function loss(G, X, Cond)
Y, ZC, logdet = G.forward(X, Cond)
f = -log_likelihood(Y) - logdet
ΔY = -∇log_likelihood(Y)
ΔX, X_ = G.backward(ΔY, Y, ZC)
return f, ΔX, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad
end

# Gradient test w.r.t. input
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
X = rand(Float32, N..., n_in, batchsize)
Cond = rand(Float32, N..., n_cond, batchsize)
X0 = rand(Float32, N..., n_in, batchsize)
Cond0 = rand(Float32, N..., n_cond, batchsize)

dX = X - X0

f0, ΔX = loss(G, X0, Cond0)[1:2]
h = 0.1f0
maxiter = 4
err1 = zeros(Float32, maxiter)
err2 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
f = loss(G, X0 + h*dX, Cond0)[1]
err1[j] = abs(f - f0)
err2[j] = abs(f - f0 - h*dot(dX, ΔX))
print(err1[j], "; ", err2[j], "\n")
global h = h/2f0
end

@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)
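# Rationale (standard gradient test): f(X0 + h*dX) - f(X0) = h*dot(dX, ∇f) + O(h^2), so err1 is
# first order in h and err2 second order. Halving h each iteration should roughly halve err1 and
# quarter err2, hence the ratios against 2^(maxiter-1) and 4^(maxiter-1).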


# Gradient test w.r.t. parameters
X = rand(Float32, N..., n_in, batchsize)
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
Gini = deepcopy(G0)

# Test one parameter from residual block and 1x1 conv
dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data
dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data

f0, ΔX, ΔW, Δv = loss(G0, X, Cond)
h = 0.1f0
maxiter = 4
err3 = zeros(Float32, maxiter)
err4 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW
G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv

f = loss(G0, X, Cond)[1]
err3[j] = abs(f - f0)
err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv))
print(err3[j], "; ", err4[j], "\n")
global h = h/2f0
end

@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)


########################################### Test with split_scales = false #########################
N = (nx,ny,nz)
########################################### Test with split_scales = true N = (nx,ny,nz) #########################
# Invertibility

# Network and input
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K)
X = rand(Float32, nx, ny, n_in, batchsize)
Cond = rand(Float32, nx, ny, n_cond, batchsize)
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
X = rand(Float32, N..., n_in, batchsize)
Cond = rand(Float32, N..., n_cond, batchsize)

Y, Cond = G.forward(X,Cond)
X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes
@@ -61,11 +167,11 @@ function loss(G, X, Cond)
end

# Gradient test w.r.t. input
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K)
X = rand(Float32, nx, ny, n_in, batchsize)
Cond = rand(Float32, nx, ny, n_cond, batchsize)
X0 = rand(Float32, nx, ny, n_in, batchsize)
Cond0 = rand(Float32, nx, ny, n_cond, batchsize)
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
X = rand(Float32, N..., n_in, batchsize)
Cond = rand(Float32, N..., n_cond, batchsize)
X0 = rand(Float32, N..., n_in, batchsize)
Cond0 = rand(Float32, N..., n_cond, batchsize)

dX = X - X0

@@ -89,9 +195,9 @@ end


# Gradient test w.r.t. parameters
X = rand(Float32, nx, ny, n_in, batchsize)
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K)
G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K)
X = rand(Float32, N..., n_in, batchsize)
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
Gini = deepcopy(G0)

# Test one parameter from residual block and 1x1 conv
