Chain ignores Base.show function of custom layer #1929
Comments
You have to implement …
What happens if you set `Flux._show_children(::BiLSTM) = ()`?
I think you also need to overload 2-arg `show` (and need not do 3-arg `show`). You also should not need to overload …
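To illustrate the point about 2-arg versus 3-arg `show`, here is a minimal sketch in plain Julia with no Flux dependency. `Layer` and `ToyChain` are hypothetical stand-ins for a custom layer such as `BiLSTM` and for a container that prints its children through `print`, i.e. through their 2-arg `show`. That this is how `Chain` prints its layers is an assumption of the sketch, not a description of Flux's actual printing code.

```julia
# Hypothetical stand-ins: `Layer` plays the role of a custom layer such as
# BiLSTM, `ToyChain` the role of a container that prints its children.
struct Layer
    n_in::Int
    n_out::Int
end

struct ToyChain
    layers::Vector{Any}
end

# 3-arg show: consulted only when a `Layer` itself is displayed at the REPL.
Base.show(io::IO, ::MIME"text/plain", l::Layer) = print(io, "Layer with 3-arg show")

# The container prints each child with `join`, which calls `print`,
# which in turn calls the child's 2-arg `show`.
function Base.show(io::IO, c::ToyChain)
    print(io, "ToyChain(")
    join(io, c.layers, ", ")
    print(io, ")")
end

chain = ToyChain([Layer(2, 4), Layer(4, 8)])

# Only the 3-arg method exists so far, so the children use Julia's default printing.
before = sprint(show, chain)

# Defining the 2-arg method is what the container actually picks up.
Base.show(io::IO, l::Layer) = print(io, "Layer(", l.n_in, " => ", l.n_out, ")")
after = sprint(show, chain)

println(before)
println(after)
```

Defining a 3-arg method on top of this is optional: when it is absent, Julia's REPL display falls back to the 2-arg one.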
It's a bit ugly to need …
Thank you for your replies. In Flux v0.12.9 it seems it does not matter if I set …

```julia
julia> using Flux

mutable struct BiLSTM{L<:Flux.Recur{T} where T<:Flux.LSTMCell}
    fw::L
    bw::L
end

function BiLSTM(in::Int, out::Int)
    if out % 2 != 0
        throw(ArgumentError("Output size must be divisible by 2."))
    end
    return BiLSTM(
        LSTM(in, out ÷ 2),
        LSTM(in, out ÷ 2)
    )
end

Flux.@functor BiLSTM
#Flux.trainable(b::BiLSTM) = (b.fw, b.bw)

Base.show(io::IO, b::BiLSTM) =
    print(io, "BiLSTM($(size(b.fw.cell.Wi, 2)), $(size(b.fw.cell.Wi, 1)÷2))")

Flux._show_children(::BiLSTM) = ()
ERROR: UndefVarError: _show_children not defined
Stacktrace:
 [1] getproperty(x::Module, f::Symbol)
   @ Base ./Base.jl:35
 [2] top-level scope
   @ ~/Dokumenty/Projekty/Julia_Smidl/VAEs/src/custom/BiLSTM.jl:24
```

Functions …
Wait, what?! Inline printing seems to be working. Weird... I think I tried this yesterday. Now I am slightly confused.

```julia
julia> encoder = enc_bilstm_dense(SEGMENT_SIZE, HIDDEN_DIM, BILSTM_LAYERS, LATENT_DIM)
Encoder(
 net: Chain(BiLSTM(2048, 4096), BiLSTM(4096, 4096), BiLSTM(4096, 4096), BiLSTM(4096, 4096)),
 μ: Dense(4096, 8192),
 logσ: Dense(4096, 8192)
)
```

```julia
using Flux

abstract type AbstractEncoder end  # assumed: defined elsewhere in the project

mutable struct Encoder{N, M, S} <: AbstractEncoder
    net::N
    μ::M
    logσ::S
end

Base.show(io::IO, ::MIME"text/plain", enc::Encoder) =
    print(io, "Encoder(\n net: $(enc.net), \n μ: $(enc.μ),\n logσ: $(enc.logσ)\n)")
```
Same song, but an octave higher. Suppose the proposed …

```julia
using Flux

mutable struct VAE
    encoder::Encoder
    decoder::Decoder
    device::Function # Flux.cpu or Flux.gpu
end

Base.show(io::IO, ::MIME"text/plain", vae::VAE) =
    print(io, "VAE(\n encoder: $(vae.encoder),\n decoder: $(vae.decoder),\n device: $(vae.device)\n)")

Flux.@functor VAE
Flux.trainable(vae::VAE) = (vae.encoder, vae.decoder)
```

Then the printout is just messy. But! There is my nice short …

```julia
julia> vae = VAE(encoder, decoder, Flux.cpu)
VAE(
 encoder: Encoder{Chain{NTuple{4, BiLSTM{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}(Chain(BiLSTM(2048, 4096), BiLSTM(4096, 4096), BiLSTM(4096, 4096), BiLSTM(4096, 4096)), Dense(4096, 8192), Dense(4096, 8192)),
 decoder: Decoder
```

Interesting stuff...
This is coming from 2-arg show on Encoder. That's what I think you should generally define, not the 3-arg one.
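A minimal sketch of that effect, again in plain Julia, with hypothetical `ToyEncoder` and `ToyVAE` types in place of the `Encoder` and `VAE` above: string interpolation such as `$(v.encoder)` goes through `print`, hence through the 2-arg `show` of the field, and Julia's default 2-arg `show` for a parametric struct spells out every type parameter. That is where a long `Encoder{Chain{...}}(...)` line comes from.

```julia
# Hypothetical stand-ins for the Encoder/VAE structs above; plain numbers
# replace Flux layers so the sketch has no dependencies.
struct ToyEncoder{N, M}
    net::N
    μ::M
end

struct ToyVAE
    encoder::ToyEncoder
end

# Only a 3-arg method for the outer type, mirroring the VAE code above.
# Interpolating `v.encoder` calls `print`, i.e. the 2-arg `show` of ToyEncoder.
Base.show(io::IO, ::MIME"text/plain", v::ToyVAE) =
    print(io, "ToyVAE(encoder: $(v.encoder))")

vae = ToyVAE(ToyEncoder(1, 2.0))

# No 2-arg method for ToyEncoder yet: the default prints its type parameters.
before = sprint(show, MIME("text/plain"), vae)

# A compact 2-arg method is what the interpolation above picks up.
Base.show(io::IO, e::ToyEncoder) = print(io, "ToyEncoder(", e.net, ", ", e.μ, ")")
after = sprint(show, MIME("text/plain"), vae)

println(before)
println(after)
```

One design note: interpolation builds the string in a fresh buffer, so `IOContext` properties of the outer `io` (for example `:compact`) do not reach the children. Writing `print(io, "ToyVAE(encoder: ", v.encoder, ")")` instead prints straight into `io` and keeps them.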
Sadly. These are the functions that recurse into models. But they really are internal; I do not think you should be adding methods to them. If you want …
I've started to define 2-arg …
Now I made a macro, #1932.
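The macro from #1932 is not reproduced here. As a hedged illustration of the general idea of generating the printing boilerplate, here is a toy macro in plain Julia; `@short_show`, `ToyDense` and `ToyModel` are hypothetical names, not Flux APIs. The macro defines a 2-arg `show` that prints the type's name without its parameters, followed by each field's own 2-arg `show`, and relies on Julia's fallback `show(io, ::MIME"text/plain", x) = show(io, x)` for REPL display.

```julia
# Hypothetical macro (not Flux's): give a struct a compact 2-arg `show` of the
# form `Name(field1, field2, ...)`, omitting the type parameters.
macro short_show(T)
    return esc(quote
        function Base.show(io::IO, x::$T)
            print(io, nameof(typeof(x)), "(")
            for (i, f) in enumerate(fieldnames(typeof(x)))
                i > 1 && print(io, ", ")
                show(io, getfield(x, f))   # each field's own 2-arg show
            end
            print(io, ")")
        end
    end)
end

# Hypothetical parametric types standing in for Flux layers and containers.
struct ToyDense{W, B}
    weight::W
    bias::B
end

struct ToyModel{E, D}
    encoder::E
    decoder::D
end

@short_show ToyDense
@short_show ToyModel

m = ToyModel(ToyDense(3, 4.5), ToyDense(1, 0.0))
println(sprint(show, m))
```

This only covers plain nesting of 2-arg `show` calls; a Flux-specific version would also have to cooperate with `Chain`'s own printing.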
Hello,

I have created a custom `BiLSTM` layer like this:

When I instantiate a `BiLSTM` cell using the `BiLSTM` function, the `Base.show` function is used as expected. But if I create a `Chain` of `BiLSTM` cells, my custom `Base.show` function is ignored.

Is there any way to make this less verbose? Something like this would be nice.

Thank you for your help!