Merge pull request #59 from slimgroup/conditional-glow
Conditional glow network and flexible residual block
mloubout authored Aug 17, 2022
2 parents a068423 + 25e999f commit 033a4a1
Showing 11 changed files with 803 additions and 192 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "InvertibleNetworks"
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
authors = ["Philipp Witte <[email protected]>", "Ali Siahkoohi <[email protected]>", "Mathias Louboutin <[email protected]>", "Gabrio Rizzuti <[email protected]>", "Rafael Orozco <[email protected]>", "Felix J. herrmann <[email protected]>"]
version = "2.1.5"
version = "2.2.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
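Once this version is registered (see the JuliaRegistrator comment at the bottom of this page), a user can pin the new release. A minimal sketch, not part of the diff, assuming the package is registered in General under the name given in Project.toml:

using Pkg
Pkg.add(name="InvertibleNetworks", version="2.2.0")   # release containing the conditional Glow layer and flexible residual block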
2 changes: 2 additions & 0 deletions src/InvertibleNetworks.jl
@@ -68,7 +68,9 @@ include("networks/invertible_network_glow.jl") # Glow: Dinh et al. (2017), King
include("networks/invertible_network_hyperbolic.jl") # Hyperbolic: Lensink et al. (2019)

# Conditional layers and nets
include("conditional_layers/conditional_layer_glow.jl")
include("conditional_layers/conditional_layer_hint.jl")
include("networks/invertible_network_conditional_glow.jl")
include("networks/invertible_network_conditional_hint.jl")
include("networks/invertible_network_conditional_hint_multiscale.jl")

153 changes: 153 additions & 0 deletions src/conditional_layers/conditional_layer_glow.jl
@@ -0,0 +1,153 @@
# Conditional coupling layer based on GLOW and cIIN
# Date: January 2022

export ConditionalLayerGlow, ConditionalLayerGlow3D


"""
CL = ConditionalLayerGlow(C::Conv1x1, RB::ResidualBlock; logdet=false)
or
CL = ConditionalLayerGlow(n_in, n_cond, n_hidden; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, ndims=2) (2D)
CL = ConditionalLayerGlow(n_in, n_cond, n_hidden; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, ndims=3) (3D)
CL = ConditionalLayerGlow3D(n_in, n_cond, n_hidden; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false) (3D)
Create a Real NVP-style invertible conditional coupling layer based on 1x1 convolutions and a residual block.
*Input*:
- `C::Conv1x1`: 1x1 convolution layer
- `RB::ResidualBlock`: residual block layer consisting of 3 convolutional layers with ReLU activations.
- `logdet`: bool to indicate whether to compute the logdet of the layer
or
- `n_in`, `n_cond`, `n_hidden`: number of channels of the input, the condition and the hidden layer
- `k1`, `k2`: kernel size of convolutions in residual block. `k1` is the kernel of the first and third
operator, `k2` is the kernel size of the second operator.
- `p1`, `p2`: padding for the first and third convolution (`p1`) and the second convolution (`p2`)
- `s1`, `s2`: stride for the first and third convolution (`s1`) and the second convolution (`s2`)
- `ndims` : number of dimensions
*Output*:
- `CL`: Invertible Real NVP conditional coupling layer.
*Usage:*
- Forward mode: `Y, logdet = CL.forward(X, C)` (if constructed with `logdet=true`)
- Inverse mode: `X = CL.inverse(Y, C)`
- Backward mode: `ΔX, X, ΔC = CL.backward(ΔY, Y, C)`
*Trainable parameters:*
- None in `CL` itself
- Trainable parameters in residual block `CL.RB` and 1x1 convolution layer `CL.C`
See also: [`Conv1x1`](@ref), [`ResidualBlock`](@ref), [`get_params`](@ref), [`clear_grad!`](@ref)
"""
struct ConditionalLayerGlow <: NeuralNetLayer
C::Conv1x1
RB::ResidualBlock
logdet::Bool
activation::ActivationFunction
end

@Flux.functor ConditionalLayerGlow

# Constructor from 1x1 convolution and residual block
function ConditionalLayerGlow(C::Conv1x1, RB::ResidualBlock; logdet=false, activation::ActivationFunction=SigmoidLayer())
RB.fan == false && throw("Set ResidualBlock.fan == true")
return ConditionalLayerGlow(C, RB, logdet, activation)
end

# Constructor from input dimensions
function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, activation::ActivationFunction=SigmoidLayer(), rb_activation::ActivationFunction=ReLUlayer(), ndims=2)

# 1x1 Convolution and residual block for invertible layers
C = Conv1x1(n_in)
RB = ResidualBlock(Int(n_in/2)+n_cond, n_hidden; n_out=n_in, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)

return ConditionalLayerGlow(C, RB, logdet, activation)
end

ConditionalLayerGlow3D(args...;kw...) = ConditionalLayerGlow(args...; kw..., ndims=3)

# Forward pass: Input X, Output Y
function forward(X::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow) where {T,N}

X_ = L.C.forward(X)
X1, X2 = tensor_split(X_)

Y2 = copy(X2)

# Cat conditioning variable C into network input
logS_T = L.RB.forward(tensor_cat(X2,C))
logS, log_T = tensor_split(logS_T)

Sm = L.activation.forward(logS)
Tm = log_T
Y1 = Sm.*X1 + Tm

Y = tensor_cat(Y1, Y2)

L.logdet == true ? (return Y, glow_logdet_forward(Sm)) : (return Y)
end

# Inverse pass: Input Y, Output X
function inverse(Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow; save=false) where {T,N}

Y1, Y2 = tensor_split(Y)

X2 = copy(Y2)
logS_T = L.RB.forward(tensor_cat(X2,C))
logS, log_T = tensor_split(logS_T)

Sm = L.activation.forward(logS)
Tm = log_T
X1 = (Y1 - Tm) ./ (Sm .+ eps(T)) # add epsilon to avoid division by 0

X_ = tensor_cat(X1, X2)
X = L.C.inverse(X_)

save == true ? (return X, X1, X2, Sm) : (return X)
end

# Backward pass: Input (ΔY, Y), Output (ΔX, X)
function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow) where {T,N}

# Recompute forward state
X, X1, X2, S = inverse(Y, C, L; save=true)

# Backpropagate residual
ΔY1, ΔY2 = tensor_split(ΔY)
ΔT = copy(ΔY1)
ΔS = ΔY1 .* X1
ΔX1 = ΔY1 .* S

if L.logdet
ΔS -= glow_logdet_backward(S)
end

# Backpropagate RB
ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), tensor_cat(X2, C))
ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=Int(size(ΔY)[N-1]/2))
ΔX2 += ΔY2

# Backpropagate 1x1 conv
ΔX = L.C.inverse((tensor_cat(ΔX1, ΔX2), tensor_cat(X1, X2)))[1]

return ΔX, X, ΔC
end
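A minimal usage sketch of the new conditional layer, not part of the diff, following the docstring above and the `.forward`/`.inverse` closure convention it documents; the channel counts and spatial sizes are illustrative, and `n_in` must be even because the input is split in half along the channel dimension:

using InvertibleNetworks, LinearAlgebra

n_in, n_cond, n_hidden = 4, 2, 16
CL = ConditionalLayerGlow(n_in, n_cond, n_hidden; logdet=true)

X = randn(Float32, 16, 16, n_in, 8)     # input:     width x height x channels x batch
C = randn(Float32, 16, 16, n_cond, 8)   # condition: same spatial size and batch, n_cond channels

Y, logdet = CL.forward(X, C)            # logdet returned because logdet=true
X_ = CL.inverse(Y, C)                   # inverse reconstructs X given Y and the same condition
norm(X - X_) / norm(X)                  # should be small, up to floating-point error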
53 changes: 31 additions & 22 deletions src/layers/layer_residual_block.jl
@@ -20,7 +20,13 @@ or
*Input*:
- `n_in`, `n_hidden`: number of input and hidden channels
- `n_in`: number of input channels
- `n_hidden`: number of hidden channels
- `n_out`: number of output channels
- `activation`: activation type between conv layers and final output
- `k1`, `k2`: kernel size of convolutions in residual block. `k1` is the kernel of the first and third
operator, `k2` is the kernel size of the second operator.
@@ -67,6 +73,7 @@ struct ResidualBlock <: NeuralNetLayer
fan::Bool
strides
pad
activation::ActivationFunction
end

@Flux.functor ResidualBlock
@@ -75,22 +82,24 @@ end
# Constructors

# Constructor
function ResidualBlock(n_in, n_hidden; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
function ResidualBlock(n_in, n_hidden; n_out=nothing, activation::ActivationFunction=ReLUlayer(), k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
# default/legacy behaviour
isnothing(n_out) && (n_out = 2*n_in)

k1 = Tuple(k1 for i=1:ndims)
k2 = Tuple(k2 for i=1:ndims)
# Initialize weights
W1 = Parameter(glorot_uniform(k1..., n_in, n_hidden))
W2 = Parameter(glorot_uniform(k2..., n_hidden, n_hidden))
W3 = Parameter(glorot_uniform(k1..., 2*n_in, n_hidden))
W3 = Parameter(glorot_uniform(k1..., n_out, n_hidden))
b1 = Parameter(zeros(Float32, n_hidden))
b2 = Parameter(zeros(Float32, n_hidden))

return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2))
return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2), activation)
end

# Constructor for given weights
function ResidualBlock(W1, W2, W3, b1, b2; p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)
function ResidualBlock(W1, W2, W3, b1, b2; activation::ActivationFunction=ReLUlayer(), p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2)

# Make weights parameters
W1 = Parameter(W1)
@@ -99,7 +108,7 @@ function ResidualBlock(W1, W2, W3, b1, b2; p1=1, p2=1, s1=1, s2=1, fan=false, nd
b1 = Parameter(b1)
b2 = Parameter(b2)

return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2))
return ResidualBlock(W1, W2, W3, b1, b2, fan, (s1, s2), (p1, p2), activation)
end

ResidualBlock3D(args...; kw...) = ResidualBlock(args...; kw..., ndims=3)
@@ -111,17 +120,17 @@ function forward(X1::AbstractArray{T, N}, RB::ResidualBlock; save=false) where {
inds =[i!=(N-1) ? 1 : Colon() for i=1:N]

Y1 = conv(X1, RB.W1.data; stride=RB.strides[1], pad=RB.pad[1]) .+ reshape(RB.b1.data, inds...)
X2 = ReLU(Y1)
X2 = RB.activation.forward(Y1)

Y2 = X2 + conv(X2, RB.W2.data; stride=RB.strides[2], pad=RB.pad[2]) .+ reshape(RB.b2.data, inds...)
X3 = ReLU(Y2)
X3 = RB.activation.forward(Y2)

cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])
Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
# Return if only recomputing state
save && (return Y1, Y2, Y3)
# Finish forward
RB.fan == true ? (return ReLU(Y3)) : (return GaLU(Y3))
RB.fan == true ? (return RB.activation.forward(Y3)) : (return GaLU(Y3))
end

# Backward
@@ -135,21 +144,21 @@ function backward(ΔX4::AbstractArray{T, N}, X1::AbstractArray{T, N},

# Cdims
cdims2 = DenseConvDims(Y2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])
cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
cdims3 = DCDims(X1, RB.W3.data; stride=RB.strides[1], padding=RB.pad[1])

# Backpropagate residual ΔX4 and compute gradients
RB.fan == true ? (ΔY3 = ReLUgrad(ΔX4, Y3)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
RB.fan == true ? (ΔY3 = RB.activation.backward(ΔX4, Y3)) : (ΔY3 = GaLUgrad(ΔX4, Y3))
ΔX3 = conv(ΔY3, RB.W3.data, cdims3)
ΔW3 = ∇conv_filter(ΔY3, ReLU(Y2), cdims3)
ΔW3 = ∇conv_filter(ΔY3, RB.activation.forward(Y2), cdims3)

ΔY2 = ReLUgrad(ΔX3, Y2)
ΔY2 = RB.activation.backward(ΔX3, Y2)
ΔX2 = ∇conv_data(ΔY2, RB.W2.data, cdims2) + ΔY2
ΔW2 = ∇conv_filter(ReLU(Y1), ΔY2, cdims2)
ΔW2 = ∇conv_filter(RB.activation.forward(Y1), ΔY2, cdims2)
Δb2 = sum(ΔY2, dims=dims)[inds...]

cdims1 = DenseConvDims(X1, RB.W1.data; stride=RB.strides[1], padding=RB.pad[1])

ΔY1 = ReLUgrad(ΔX2, Y1)
ΔY1 = RB.activation.backward(ΔX2, Y1)
ΔX1 = ∇conv_data(ΔY1, RB.W1.data, cdims1)
ΔW1 = ∇conv_filter(X1, ΔY1, cdims1)
Δb1 = sum(ΔY1, dims=dims)[inds...]
@@ -177,22 +186,22 @@ function jacobian(ΔX1::AbstractArray{T, N}, Δθ::Array{Parameter, 1},

Y1 = conv(X1, RB.W1.data, cdims1) .+ reshape(RB.b1.data, inds...)
ΔY1 = conv(ΔX1, RB.W1.data, cdims1) + conv(X1, Δθ[1].data, cdims1) .+ reshape(Δθ[4].data, inds...)
X2 = ReLU(Y1)
ΔX2 = ReLUgrad(ΔY1, Y1)
X2 = RB.activation.forward(Y1)
ΔX2 = RB.activation.backward(ΔY1, Y1)

cdims2 = DenseConvDims(X2, RB.W2.data; stride=RB.strides[2], padding=RB.pad[2])

Y2 = X2 + conv(X2, RB.W2.data, cdims2) .+ reshape(RB.b2.data, inds...)
ΔY2 = ΔX2 + conv(ΔX2, RB.W2.data, cdims2) + conv(X2, Δθ[2].data, cdims2) .+ reshape(Δθ[5].data, inds...)
X3 = ReLU(Y2)
ΔX3 = ReLUgrad(ΔY2, Y2)
X3 = RB.activation.forward(Y2)
ΔX3 = RB.activation.backward(ΔY2, Y2)

cdims3 = DCDims(X1, RB.W3.data; nc=2*size(X1, N-1), stride=RB.strides[1], padding=RB.pad[1])
Y3 = ∇conv_data(X3, RB.W3.data, cdims3)
ΔY3 = ∇conv_data(ΔX3, RB.W3.data, cdims3) + ∇conv_data(X3, Δθ[3].data, cdims3)
if RB.fan == true
X4 = ReLU(Y3)
ΔX4 = ReLUgrad(ΔY3, Y3)
X4 = RB.activation.forward(Y3)
ΔX4 = RB.activation.backward(ΔY3, Y3)
else
ΔX4, X4 = GaLUjacobian(ΔY3, Y3)
end
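A minimal sketch of the added residual-block flexibility, not part of the diff: `n_out` decouples the output channel count from the legacy `2*n_in` default and `activation` replaces the hard-coded ReLU. It assumes `ResidualBlock` and `ReLUlayer` are exported and that the `.forward` closure convention used in the layer docstrings applies; sizes are illustrative.

using InvertibleNetworks

RB_legacy = ResidualBlock(4, 32)                                        # legacy default: n_out = 2*n_in = 8, ReLU activation
RB_custom = ResidualBlock(4, 32; n_out=6, activation=ReLUlayer(), fan=true)

X = randn(Float32, 16, 16, 4, 8)
Y = RB_custom.forward(X)                                                # fan=true: output passed through the chosen activation
size(Y)                                                                 # (16, 16, 6, 8) -- n_out channels instead of 2*n_in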

1 comment on commit 033a4a1

@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/66442

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v2.2.0 -m "<description of version>" 033a4a13db1e3b7395d560313cf9fdf2e10640cb
git push origin v2.2.0
