Chain ignores Base.show function of custom layer #1929

Closed
fpartl opened this issue Apr 5, 2022 · 9 comments

@fpartl
Contributor

fpartl commented Apr 5, 2022

Hello,

I have created a custom BiLSTM layer like this:

using Flux

mutable struct BiLSTM{L<:Flux.Recur{T} where T<:Flux.LSTMCell}
    fw::L
    bw::L
end

function BiLSTM(in::Int, out::Int)
    if out % 2 != 0
        throw(ArgumentError("Output size must be divisible by 2."))
    end

    return BiLSTM(
        LSTM(in, out ÷ 2), 
        LSTM(in, out ÷ 2)
    )
end

Flux.@functor BiLSTM
Flux.trainable(b::BiLSTM) = (b.fw, b.bw)

Base.show(io::IO, ::MIME"text/plain", b::BiLSTM) = 
    print(io, "BiLSTM($(size(b.fw.cell.Wi, 2)), $(size(b.fw.cell.Wi, 1)÷2))")

# ...

When I instantiate a BiLSTM cell using the BiLSTM function, the Base.show function is used as expected.

julia> BiLSTM(12, 16)
BiLSTM(12, 16)

But if I create a Chain of BiLSTM cells, my custom Base.show function is ignored.

julia> Chain(BiLSTM(12, 16), BiLSTM(16, 16), BiLSTM(16, 16))
Chain(
  BiLSTM(
    Recur(
      LSTMCell(12, 8),                  # 688 parameters
    ),
    Recur(
      LSTMCell(12, 8),                  # 688 parameters
    ),
  ),
  BiLSTM(
    Recur(
      LSTMCell(16, 8),                  # 816 parameters
    ),
    Recur(
      LSTMCell(16, 8),                  # 816 parameters
    ),
  ),
  BiLSTM(
    Recur(
      LSTMCell(16, 8),                  # 816 parameters
    ),
    Recur(
      LSTMCell(16, 8),                  # 816 parameters
    ),
  ),
)         # Total: 30 trainable arrays, 4_640 parameters,
          # plus 12 non-trainable, 96 parameters, summarysize 20.141 KiB.

Is there any way to make this less verbose? Something like this would be nice.

julia> Chain(BiLSTM(12, 16), BiLSTM(16, 16), BiLSTM(16, 16))
Chain(
  BiLSTM(12, 16),                      # 1376 parameters
  BiLSTM(16, 16),                      # 1632 parameters
  BiLSTM(16, 16),                      # 1632 parameters
)         # Total: 30 trainable arrays, 4_640 parameters,
          # plus 12 non-trainable, 96 parameters, summarysize 20.141 KiB.

Thank you for your help!

@CarloLucibello
Member

You have to implement Flux._big_show and Flux._layer_show for that; see src/layers/show.jl. I think we should document this interface somewhere.

@CarloLucibello
Member

What happens if you set

Flux._show_children(::BiLSTM) = ()

?

@mcabbott
Member

mcabbott commented Apr 6, 2022

I think you also need to overload the 2-arg show (and need not define the 3-arg show). You should not need to overload trainable either.

julia> Flux.@functor BiLSTM

julia> Base.show(io::IO, b::BiLSTM) = 
           print(io, "BiLSTM($(size(b.fw.cell.Wi, 2)), $(size(b.fw.cell.Wi, 1)÷2))")

julia> BiLSTM(12, 16)
BiLSTM(12, 16)

julia> Flux._show_children(::BiLSTM) = ()

julia> Chain(BiLSTM(12, 16), BiLSTM(16, 16), BiLSTM(16, 16))
Chain(
  BiLSTM(12, 16),                       # 1_376 parameters, plus 32
  BiLSTM(16, 16),                       # 1_632 parameters, plus 32
  BiLSTM(16, 16),                       # 1_632 parameters, plus 32
)         # Total: 30 trainable arrays, 4_640 parameters,
          # plus 12 non-trainable, 96 parameters, summarysize 20.141 KiB.

It's a bit ugly to need Flux._show_children. It wasn't really intended as a user-facing function, since the goal was for this to be fully automatic... but here we are.
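
To see why returning an empty tuple from a children hook collapses a layer to a single line, here is a toy model of recursive printing in plain Julia. It is only a sketch and not Flux's implementation: `toy_children` and `toy_big_show` are hypothetical stand-ins for the internal `Flux._show_children` and the recursive printer, and `ToyChain`, `VerboseBiLSTM`, and `QuietBiLSTM` are made-up types.

```julia
# Toy model of recursive layer printing. NOT Flux's actual code:
# every name below is hypothetical and exists only for illustration.
struct ToyChain
    layers::Tuple
end

struct VerboseBiLSTM
    fw::String
    bw::String
end

struct QuietBiLSTM
    fw::String
    bw::String
end

# Short one-line form, supplied through 2-arg show.
Base.show(io::IO, ::QuietBiLSTM) = print(io, "QuietBiLSTM(...)")

# Hook playing the role of Flux._show_children.
toy_children(x) = ()                              # anything unknown is a leaf
toy_children(c::ToyChain) = c.layers
toy_children(b::VerboseBiLSTM) = (b.fw, b.bw)     # expose the fields
toy_children(::QuietBiLSTM) = ()                  # stop the recursion here

function toy_big_show(io::IO, x, indent::Int=0)
    kids = toy_children(x)
    if isempty(kids)
        # Leaf: a single line, printed through 2-arg show.
        println(io, " "^indent, x, ",")
    else
        # Container: print the type name, then recurse into each child.
        println(io, " "^indent, nameof(typeof(x)), "(")
        foreach(k -> toy_big_show(io, k, indent + 2), kids)
        println(io, " "^indent, "),")
    end
end

expanded  = sprint(toy_big_show, ToyChain((VerboseBiLSTM("fwd", "bwd"), VerboseBiLSTM("fwd", "bwd"))))
collapsed = sprint(toy_big_show, ToyChain((QuietBiLSTM("fwd", "bwd"), QuietBiLSTM("fwd", "bwd"))))
print(expanded)
print(collapsed)
```

In this toy, the layer whose hook returns `()` is printed through its 2-arg `show`, so the chain occupies four lines instead of ten. That mirrors why both the 2-arg `show` method and the empty children tuple were needed above.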

@fpartl
Contributor Author

fpartl commented Apr 6, 2022

Thank you for your replies. In Flux v0.12.9 it does not seem to matter whether I define trainable or not: the Flux._show_children function is simply not defined. I do not see what I am doing wrong.

julia> using Flux
       
       mutable struct BiLSTM{L<:Flux.Recur{T} where T<:Flux.LSTMCell}
           fw::L
           bw::L
       end
       
       function BiLSTM(in::Int, out::Int)
           if out % 2 != 0
               throw(ArgumentError("Output size must be divisible by 2."))
           end
       
           return BiLSTM(
               LSTM(in, out ÷ 2), 
               LSTM(in, out ÷ 2)
           )
       end
       
       Flux.@functor BiLSTM
       #Flux.trainable(b::BiLSTM) = (b.fw, b.bw)
       Base.show(io::IO, b::BiLSTM) = 
           print(io, "BiLSTM($(size(b.fw.cell.Wi, 2)), $(size(b.fw.cell.Wi, 1)÷2))")
       Flux._show_children(::BiLSTM) = ()
ERROR: UndefVarError: _show_children not defined
Stacktrace:
 [1] getproperty(x::Module, f::Symbol)
   @ Base ./Base.jl:35
 [2] top-level scope
   @ ~/Dokumenty/Projekty/Julia_Smidl/VAEs/src/custom/BiLSTM.jl:24

The functions Flux._big_show and Flux._layer_show in src/layers/show.jl look way too complicated for my level of punctiliousness. A tutorial in the official documentation would be very nice indeed. 👍

@fpartl
Contributor Author

fpartl commented Apr 6, 2022

Wait, what?! Inline printing seems to be working. Weird... I think I tried this yesterday. Now I am slightly confused.

julia> encoder = enc_bilstm_dense(SEGMENT_SIZE, HIDDEN_DIM, BILSTM_LAYERS, LATENT_DIM)
Encoder(
     net: Chain(BiLSTM(2048, 4096), BiLSTM(4096, 4096), BiLSTM(4096, 4096), BiLSTM(4096, 4096)), 
       μ: Dense(4096, 8192),
    logσ: Dense(4096, 8192)
)

enc_bilstm_dense is just a factory function that creates an Encoder instance, which has this Base.show function:

using Flux

mutable struct Encoder{N, M, S} <: AbstractEncoder
    net::N
    μ::M
    logσ::S
end

Base.show(io::IO, ::MIME"text/plain", enc::Encoder) =
    print(io, "Encoder(\n     net: $(enc.net), \n       μ: $(enc.μ),\n    logσ: $(enc.logσ)\n)")

@fpartl
Contributor Author

fpartl commented Apr 6, 2022

Same song, but an octave higher. Suppose the Encoder above is part of a VAE structure, something like this:

using Flux

mutable struct VAE
    encoder::Encoder
    decoder::Decoder
    device::Function   # Flux.cpu or Flux.gpu
end

Base.show(io::IO, ::MIME"text/plain", vae::VAE) =
    print(io, "VAE(\n    encoder: $(vae.encoder),\n    decoder: $(vae.decoder),\n    device: $(vae.device)\n)")

Flux.@functor VAE
Flux.trainable(vae::VAE) = (vae.encoder, vae.decoder)

Then the printout is just messy. But! My nice short Base.show output does appear at the end of the encoder entry.

julia> vae = VAE(encoder, decoder, Flux.cpu)
VAE(
    encoder: Encoder{Chain{NTuple{4, BiLSTM{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(Chain(BiLSTM(2048, 4096), BiLSTM(4096, 4096), BiLSTM(4096, 4096), BiLSTM(4096, 4096)), Dense(4096, 8192), Dense(4096, 8192)),
    decoder: Decoder

Interesting stuff...

@mcabbott
Member

mcabbott commented Apr 6, 2022

Encoder{Chain{NTuple{4, BiLSTM

This is coming from 2-arg show on Encoder. That's what I think you should generally define, not the 3-arg one.
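
The difference between the two methods can be checked without Flux. The sketch below uses two hypothetical types, `OnlyThreeArg` and `OnlyTwoArg`, and two helper names, `repl_view` and `nested_view`, which are assumptions made for illustration. It relies on standard Julia behaviour: the REPL displays a top-level value through 3-arg `show` with `MIME"text/plain"`, which falls back to 2-arg `show`, whereas `print` and string interpolation call 2-arg `show` directly.

```julia
# Hypothetical stand-in types; only Base is needed to see the dispatch.
struct OnlyThreeArg
    n::Int
end

struct OnlyTwoArg
    n::Int
end

# One type customises only the 3-arg method, the other only the 2-arg method.
Base.show(io::IO, ::MIME"text/plain", x::OnlyThreeArg) = print(io, "Custom3(", x.n, ")")
Base.show(io::IO, x::OnlyTwoArg) = print(io, "Custom2(", x.n, ")")

# What the REPL displays for a top-level value (3-arg show).
repl_view(x) = sprint(show, MIME("text/plain"), x)

# What print, interpolation, and enclosing objects use (2-arg show).
nested_view(x) = sprint(show, x)

# The 3-arg-only type looks customised at top level, but falls back to
# the default struct printing as soon as it is interpolated or nested.
println(repl_view(OnlyThreeArg(1)))
println(nested_view(OnlyThreeArg(1)))
println("net: $(OnlyThreeArg(1))")

# The 2-arg-only type is customised in both situations, because the
# default 3-arg show falls back to the 2-arg method.
println(repl_view(OnlyTwoArg(1)))
println(nested_view(OnlyTwoArg(1)))
println("net: $(OnlyTwoArg(1))")
```

This is the same effect as in the VAE printout: the Encoder only had a 3-arg method, so interpolating it into the VAE string used the default 2-arg printing with all type parameters.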

The functions Flux._big_show and Flux._layer_show in src/layers/show.jl look way too complicated

Sadly. These are the functions that recurse into models. But they really are internal, I do not think you should be adding methods to them.

If you want Encoder, as the outermost object being shown, to act like Chain and start the recursive printing, then its 3-arg show needs to call _big_show, as for Chain here: https://github.com/FluxML/Flux.jl/blob/master/src/layers/show.jl#L5-L7

@fpartl
Contributor Author

fpartl commented Apr 6, 2022

I've started to define 2-arg show for my custom cells/layers and it's looking good enough. Thank you! 👍

@mcabbott
Member

mcabbott commented Apr 7, 2022

Now I made a macro, #1932.

What that doesn't do is provide an easy way to stop the recursion, which was the original request above. In the latest version, Flux.@layer BiLSTM will by default mark this layer as having no further children to expand.
