Skip to content

Commit

Permalink
Merge pull request #78 from itan1/add-leakyrelu
Browse files Browse the repository at this point in the history
Add leakyrelu
  • Loading branch information
DrChainsaw authored Jul 10, 2023
2 parents 8e1ccba + 1f2ea0d commit 6c15b18
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ Gemm
GlobalAveragePool
GlobalMaxPool
LSTM
LeakyRelu
MatMul
MaxPool
Mul
Expand Down
6 changes: 6 additions & 0 deletions src/deserialize/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ constant(::Val{:value}, val) = val
actfuns[:Relu] = params -> Flux.relu
actfuns[:Sigmoid] = params -> Flux.σ

actfuns[:LeakyRelu] = function(params)
α = get(params, :alpha, 0.01f0)
return x -> Flux.leakyrelu(x, oftype(x, α))
end
rnnactfuns[:LeakyRelu] = (ind, params) -> actfuns[:LeakyRelu](Dict(:alpha => get(params, :activation_alpha, ntuple(i -> 0.01f0, ind))[ind]))

actfuns[:Elu] = function(params)
α = get(params, :alpha, 1)
return x -> Flux.elu(x, oftype(x, α))
Expand Down
1 change: 1 addition & 0 deletions src/serialize/serialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ function attribfun(fhshape, optype, pps::AbstractProbe...; attributes = ONNX.Att
end

Flux.relu(pp::AbstractProbe) = attribfun(identity, "Relu", pp)
Flux.leakyrelu(pp::AbstractProbe, α=0.01f0) = attribfun(identity, "LeakyRelu", pp; attributes = [ONNX.AttributeProto("alpha", α)])
Flux.elu(pp::AbstractProbe, α=1f0) = attribfun(identity, "Elu", pp; attributes = [ONNX.AttributeProto("alpha", α)])
Flux.selu(pp::AbstractProbe) = attribfun(identity, "Selu", pp)
Flux.selu(pp::AbstractProbe, γ, α) = attribfun(identity, "Selu", pp; attributes = ONNX.AttributeProto.(["gamma", "alpha"], [γ, α]))
Expand Down
17 changes: 13 additions & 4 deletions test/deserialize/Artifacts.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ git-tree-sha1 = "8be97aa969ebdbe7599798d511a7790eba0697f2"
git-tree-sha1 = "41eca620fb09f7d90ec9b875a80388566baadada"

[test_div]
git-tree-sha1 = "430fd9135e60076904f717971bd174b0f16e1c54"
git-tree-sha1 = "57dd66f7274aac0e2a462e49dadf5a551c4e5e80"

[test_dropout_default]
git-tree-sha1 = "70fe420142b8d29b708578e4b6f2929e6907cb4c"
Expand Down Expand Up @@ -169,14 +169,23 @@ git-tree-sha1 = "377710458916cc790bb7eec00c8e3f0719680cf8"
[test_globalmaxpool_precomputed]
git-tree-sha1 = "6d72b58370176351d46937ca3df65ba2fd114f04"

[test_leakyrelu]
git-tree-sha1 = "07afe319b71db2cb6bc295ff9409482721473817"

[test_leakyrelu_default]
git-tree-sha1 = "2751dbd14e5feaf6c59798e84ae9e7d9700240b6"

[test_leakyrelu_example]
git-tree-sha1 = "b7e814cb5b5d6d538db1d6af49c02b786cb0036e"

[test_lstm_defaults]
git-tree-sha1 = "c8b0d06dc9733222906bb6471c39a9c41270d149"

[test_lstm_with_initial_bias]
git-tree-sha1 = "19fe9305067a4225c6dd76264bb342cf966546ae"

[test_matmul_2d]
git-tree-sha1 = "481de0ea5b1fb4692f10920215ce701df1b7ba09"
git-tree-sha1 = "3008bfa77da6160c406f6dae414b19861fef9f13"

[test_maxpool_1d_default]
git-tree-sha1 = "9b0a2b97518eb68122276b242313529582e4be95"
Expand Down Expand Up @@ -269,10 +278,10 @@ git-tree-sha1 = "44c37442a35def50d4e1230cdcff8a986899d18d"
git-tree-sha1 = "dc4a6180985e796aca6997ae79137fee6f9c05e9"

[test_sigmoid]
git-tree-sha1 = "2bb16571d0809d1e6216a1b6eb5b27d332fac4f0"
git-tree-sha1 = "6f0e41cd8b1498f3c60b3d00b9558bf311217bf8"

[test_sigmoid_example]
git-tree-sha1 = "710a8b85b3a7e9e301af21f8f1c0e7b087536ba5"
git-tree-sha1 = "b8b836fd3cb97d2801777c03e98d6cd41bc17b2d"

[test_softmax_axis_0]
git-tree-sha1 = "4f090b0b0f540b5f133176e4cf8a118c24db1886"
Expand Down
3 changes: 3 additions & 0 deletions test/deserialize/deserialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ end
(name="test_elu_default", ninputs=1, noutputs=1),
(name="test_elu_example", ninputs=1, noutputs=1),
(name="test_relu", ninputs=1, noutputs=1),
(name="test_leakyrelu", ninputs=1, noutputs=1),
(name="test_leakyrelu_default", ninputs=1, noutputs=1),
(name="test_leakyrelu_example", ninputs=1, noutputs=1),
(name="test_selu", ninputs=1, noutputs=1),
(name="test_selu_default", ninputs=1, noutputs=1),
(name="test_selu_example", ninputs=1, noutputs=1),
Expand Down
3 changes: 3 additions & 0 deletions test/serialize/serialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

@testset "Paramfree op $(tc.op) attrs: $(pairs(tc.attr))" for tc in (
(op=:Relu, attr = Dict(), fd=actfuns),
(op=:LeakyRelu, attr = Dict(), fd=actfuns),
(op=:LeakyRelu, attr = Dict(:alpha => 0.05f0), fd=actfuns),
(op=:Elu, attr = Dict(), fd=actfuns),
(op=:Elu, attr = Dict(:alpha => 0.5f0), fd=actfuns),
(op=:Selu, attr = Dict(), fd=actfuns),
Expand Down Expand Up @@ -154,6 +156,7 @@

@testset "Layer with activation function $actfun" for actfun in (
relu,
leakyrelu,
elu,
selu,
tanh,
Expand Down

0 comments on commit 6c15b18

Please sign in to comment.