Commit 8902277: Merge branch 'master' into add-instancenorm

DrChainsaw authored Jul 11, 2023
2 parents 9b5db38 + ee7f5f0

Showing 8 changed files with 68 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -79,6 +79,7 @@ BatchNormalization
Concat
Constant
Conv
ConvTranspose
Div
Dropout
Elu
6 changes: 3 additions & 3 deletions src/ONNXNaiveNASflux.jl
@@ -9,9 +9,9 @@ using Flux: params
using NaiveNASflux
using NaiveNASflux: weights, bias
using NaiveNASflux: indim, outdim, actdim, actrank, layertype, wrapped
-using NaiveNASflux: FluxLayer, FluxParLayer, FluxNoParLayer, FluxDense, FluxConvolutional, FluxConv, FluxBatchNorm,
-                    FluxInstanceNorm, FluxRecurrent, FluxRnn, FluxLstm, FluxGru, FluxTransparentLayer, FluxPoolLayer,
-                    FluxDropOut, Flux2D, GenericFluxConvolutional, GenericFlux2D, GenericFluxRecurrent
+using NaiveNASflux: FluxLayer, FluxParLayer, FluxNoParLayer, FluxDense, FluxConvolutional, FluxConv, FluxConvTranspose,
+                    FluxBatchNorm, FluxInstanceNorm, FluxRecurrent, FluxRnn, FluxLstm, FluxGru, FluxTransparentLayer,
+                    FluxPoolLayer, FluxDropOut, Flux2D, GenericFluxConvolutional, GenericFlux2D, GenericFluxRecurrent
using Setfield
using Statistics
import Pkg
11 changes: 11 additions & 0 deletions src/deserialize/ops.jl
@@ -94,6 +94,17 @@ actlayers[:Conv] = function(params, weight::AbstractArray{T, N}, bias=false) whe
end
fluxlayertypes[:Conv] = (weight, bias=nothing) -> FluxConv{length(size(weight))-2}()

actlayers[:ConvTranspose] = function(params, weight::AbstractArray{T, N}, bias=false) where {T, N}
    a,_,p,s,d = akpsd(params)

    @assert get(params, :group, 1) == 1 "Group size not supported!" # TODO
    @assert !haskey(params, :output_shape) "ConvTranspose: output_shape not supported"
    @assert !haskey(params, :output_padding) "ConvTranspose: output_padding not supported"

    return ConvTranspose(flipweights(FluxConvTranspose{N-2}(), weight), bias, a, pad=p, stride=s, dilation=d)
end
fluxlayertypes[:ConvTranspose] = (weight, bias=nothing) -> FluxConvTranspose{length(size(weight))-2}()

biasarray(b::Bool, esize, β) = b
biasarray(b::AbstractArray, esize, β) = length(b) === 1 ? repeat(β .* vec(b), esize) : β .* reshape(b, :)
biasarray(b::Number, esize, β) = repeat([β * b], esize)
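
For context, the constructor registered above builds an ordinary Flux layer from the ONNX node: akpsd (a package-internal helper) unpacks activation, kernel shape, pads, strides and dilations from the node parameters, and flipweights reorders the spatial dimensions of the ONNX weight tensor into Flux's layout. A minimal standalone sketch of the kind of layer this produces, not part of this commit and with all values illustrative:

using Flux

# Flux's ConvTranspose stores weights as (spatial..., C_out, C_in),
# so a 3x3 kernel taking 2 channels to 4 has size (3, 3, 4, 2).
weight = randn(Float32, 3, 3, 4, 2)
layer = ConvTranspose(weight, false, identity; pad=1, stride=2, dilation=1)
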
1 change: 1 addition & 0 deletions src/serialize/serialize.jl
@@ -270,6 +270,7 @@ function(l::Flux.Dense)(pp::AbstractProbe)
end

(l::Flux.Conv)(pp::AbstractProbe) = weightlayer(layertype(l), l, pp, "Conv"; attributes = attribs(l))
(l::Flux.ConvTranspose)(pp::AbstractProbe) = weightlayer(layertype(l), l, pp, "ConvTranspose"; attributes = attribs(l))

attribs(l) = attribs(layertype(l), l)
attribs(::FluxConvolutional{N}, l) where N = ONNX.AttributeProto.([ "pads", "strides", "dilations"], [padexpand(Val(N), l.pad), reverse(l.stride), reverse(l.dilation)])
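
With the ConvTranspose op registered in ops.jl for deserialization and the probe method above for serialization, a layer should survive a round trip through an ONNX file. A hedged sanity-check sketch, assuming the package's save and load entry points (file name illustrative):

using ONNXNaiveNASflux, Flux

model = Chain(ConvTranspose((3, 3), 3 => 4, relu; pad=1, stride=2))
ONNXNaiveNASflux.save("convtranspose.onnx", model)
loaded = ONNXNaiveNASflux.load("convtranspose.onnx")

x = randn(Float32, 9, 9, 3, 2)
loaded(x) ≈ model(x)   # expected to hold if all attributes round-trip
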
14 changes: 14 additions & 0 deletions src/shapes.jl
@@ -91,6 +91,20 @@ function outshape(::FluxConvolutional{N}, l, s::Tuple) where N
    return (o..., nout(l), s[N+2])
end

function outshape(::FluxConvTranspose{N}, l, s::Tuple) where N
    assertshape(s, N+2, l)
    assertsize(s[N+1], nin(l)[], l)
    p = length(l.pad) == N ? 2 .* l.pad : l.pad[1:2:end] .+ l.pad[2:2:end]
    k = size(weights(l))[1:N]
    d = l.dilation
    stride = l.stride

    o = map(zip(1:N, s)) do (i, si)
        aggshape(x -> (stride[i] * (x - 1) + ((k[i] - 1) * d[i] + 1) - p[i]), si)
    end
    return (o..., nout(l), s[N+2])
end

outshape(l::Union{Flux.MaxPool{N}, Flux.MeanPool{N}}, ::Missing) where N = outshape(l, ntuple(i->missing, N+2))
function outshape(l::Union{Flux.MaxPool{N}, Flux.MeanPool{N}}, s::Tuple) where N
    assertshape(s, N+2, l)
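
The expression passed to aggshape above is the standard transposed-convolution output size, o = stride*(i - 1) + dilation*(k - 1) + 1 - p, with p the total padding over both edges of a dimension. A quick illustrative check of that formula against Flux itself, separate from this change:

using Flux

l = ConvTranspose((3, 3), 2 => 4; stride=2, pad=1, dilation=1)
x = randn(Float32, 5, 5, 2, 1)
# 2*(5-1) + 1*(3-1) + 1 - 2 == 9 along each spatial dimension
size(l(x))[1:2]   # (9, 9)
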
27 changes: 27 additions & 0 deletions test/deserialize/Artifacts.toml
@@ -82,6 +82,33 @@ git-tree-sha1 = "8be97aa969ebdbe7599798d511a7790eba0697f2"
[test_conv_with_strides_padding]
git-tree-sha1 = "41eca620fb09f7d90ec9b875a80388566baadada"

[test_convtranspose]
git-tree-sha1 = "941508204a2655148ca271b54790c27911fbc8f8"

[test_convtranspose_1d]
git-tree-sha1 = "c62ef0df9eed662b0f673def672291e2b883d9bf"

[test_convtranspose_3d]
git-tree-sha1 = "78f42b71d517b85ede369a446816d64007777343"

[test_convtranspose_dilations]
git-tree-sha1 = "11ef17d087338b3c98100b62cbe5f387a7b9b026"

[test_convtranspose_kernel_shape]
git-tree-sha1 = "216fb10bf7fa2f66e8da0913a4919910bd6723d8"

[test_convtranspose_output_shape]
git-tree-sha1 = "c2c636fc9333cfeb706761d015fd49634c710dc6"

[test_convtranspose_pad]
git-tree-sha1 = "bcfa9e154fc3b4b9bb396586a753de36d17cac59"

[test_convtranspose_pads]
git-tree-sha1 = "2b996154c62040de848d2fbc83e551651e1fbd67"

[test_convtranspose_with_kernel]
git-tree-sha1 = "7e2ea896245f3f90dae5dec81289b75f198bb07b"

[test_div]
git-tree-sha1 = "57dd66f7274aac0e2a462e49dadf5a551c4e5e80"

9 changes: 9 additions & 0 deletions test/deserialize/deserialize.jl
@@ -77,6 +77,15 @@ end
(name="test_conv_with_strides_and_asymmetric_padding", ninputs=2, noutputs=1),
(name="test_conv_with_strides_no_padding", ninputs=2, noutputs=1),
(name="test_conv_with_strides_padding", ninputs=2, noutputs=1),
(name="test_convtranspose", ninputs=2, noutputs=1),
(name="test_convtranspose_1d", ninputs=2, noutputs=1),
(name="test_convtranspose_3d", ninputs=2, noutputs=1),
(name="test_convtranspose_dilations", ninputs=2, noutputs=1),
#(name="test_convtranspose_kernel_shape", ninputs=2, noutputs=1), Not supported!
#(name="test_convtranspose_output_shape", ninputs=2, noutputs=1), Not supported!
#(name="test_convtranspose_pad", ninputs=2, noutputs=1), Not supported!
(name="test_convtranspose_pads", ninputs=2, noutputs=1),
#(name="test_convtranspose_with_kernel", ninputs=2, noutputs=1), Not supported!
(name="test_dropout_default", ninputs=1, noutputs=1),
(name="test_dropout_random", ninputs=1, noutputs=1),
#(name="test_gemm_all_attributes", ninputs=3, noutputs=1), Not supported!
2 changes: 2 additions & 0 deletions test/serialize/serialize.jl
@@ -165,6 +165,8 @@
(layer=Dense(3,4, actfun), indata=reshape(collect(Float32, 1:12), :, 4) .- 3),
(layer=Conv((1,2), 3=>4, actfun; pad=(2,1), stride=(1,2), dilation=3), indata=reshape(collect(Float32, 1:2*3*9*9), 9,9,3,2) .- 5),
(layer=Conv((2,3), 3=>4, actfun; pad=(1,2,3,4), stride=(1,2), dilation=3), indata=reshape(collect(Float32, 1:2*3*9*9), 9,9,3,2) .- 10),
(layer=ConvTranspose((3,3), 3=>4, actfun), indata=reshape(collect(Float32, 1:2*3*9*9), 9,9,3,2) .- 10),
(layer=ConvTranspose((2,3), 3=>4, actfun; pad=(1,2,3,4), stride=(1,2), dilation=3), indata=reshape(collect(Float32, 1:2*3*9*9), 9,9,3,2) .- 10),
)

inprobe = NodeProbe("input", genname, shape(layertype(tc.layer), nin(tc.layer)))