From aa65466c26f719cee9074a583aa31d7990b15582 Mon Sep 17 00:00:00 2001 From: Isabelle Tan Date: Sun, 9 Jul 2023 13:28:49 +0200 Subject: [PATCH 1/2] Add LeakyRelu --- README.md | 1 + src/deserialize/ops.jl | 6 ++++++ src/serialize/serialize.jl | 1 + test/deserialize/Artifacts.toml | 9 +++++++++ test/deserialize/deserialize.jl | 3 +++ test/serialize/serialize.jl | 3 +++ 6 files changed, 23 insertions(+) diff --git a/README.md b/README.md index f515937..b3925b4 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,7 @@ Gemm GlobalAveragePool GlobalMaxPool LSTM +LeakyRelu MatMul MaxPool Mul diff --git a/src/deserialize/ops.jl b/src/deserialize/ops.jl index 7a04f13..1005331 100644 --- a/src/deserialize/ops.jl +++ b/src/deserialize/ops.jl @@ -51,6 +51,12 @@ constant(::Val{:value}, val) = val actfuns[:Relu] = params -> Flux.relu actfuns[:Sigmoid] = params -> Flux.σ +actfuns[:LeakyRelu] = function(params) + α = get(params, :alpha, 0.01f0) + return x -> Flux.leakyrelu(x, oftype(x, α)) +end +rnnactfuns[:LeakyRelu] = (ind, params) -> actfuns[:LeakyRelu](Dict(:alpha => get(params, :activation_alpha, ntuple(i -> 0.01f0, ind))[ind])) + actfuns[:Elu] = function(params) α = get(params, :alpha, 1) return x -> Flux.elu(x, oftype(x, α)) diff --git a/src/serialize/serialize.jl b/src/serialize/serialize.jl index 4e90244..3faac18 100644 --- a/src/serialize/serialize.jl +++ b/src/serialize/serialize.jl @@ -365,6 +365,7 @@ function attribfun(fhshape, optype, pps::AbstractProbe...; attributes = ONNX.Att end Flux.relu(pp::AbstractProbe) = attribfun(identity, "Relu", pp) +Flux.leakyrelu(pp::AbstractProbe, α=0.01f0) = attribfun(identity, "LeakyRelu", pp; attributes = [ONNX.AttributeProto("alpha", α)]) Flux.elu(pp::AbstractProbe, α=1f0) = attribfun(identity, "Elu", pp; attributes = [ONNX.AttributeProto("alpha", α)]) Flux.selu(pp::AbstractProbe) = attribfun(identity, "Selu", pp) Flux.selu(pp::AbstractProbe, γ, α) = attribfun(identity, "Selu", pp; attributes = ONNX.AttributeProto.(["gamma", "alpha"], [γ, α])) diff --git a/test/deserialize/Artifacts.toml b/test/deserialize/Artifacts.toml index 865ab3b..3e5ed0b 100644 --- a/test/deserialize/Artifacts.toml +++ b/test/deserialize/Artifacts.toml @@ -169,6 +169,15 @@ git-tree-sha1 = "377710458916cc790bb7eec00c8e3f0719680cf8" [test_globalmaxpool_precomputed] git-tree-sha1 = "6d72b58370176351d46937ca3df65ba2fd114f04" +[test_leakyrelu] +git-tree-sha1 = "07afe319b71db2cb6bc295ff9409482721473817" + +[test_leakyrelu_default] +git-tree-sha1 = "2751dbd14e5feaf6c59798e84ae9e7d9700240b6" + +[test_leakyrelu_example] +git-tree-sha1 = "b7e814cb5b5d6d538db1d6af49c02b786cb0036e" + [test_lstm_defaults] git-tree-sha1 = "c8b0d06dc9733222906bb6471c39a9c41270d149" diff --git a/test/deserialize/deserialize.jl b/test/deserialize/deserialize.jl index 68555ad..70bc47d 100644 --- a/test/deserialize/deserialize.jl +++ b/test/deserialize/deserialize.jl @@ -139,6 +139,9 @@ end (name="test_elu_default", ninputs=1, noutputs=1), (name="test_elu_example", ninputs=1, noutputs=1), (name="test_relu", ninputs=1, noutputs=1), + (name="test_leakyrelu", ninputs=1, noutputs=1), + (name="test_leakyrelu_default", ninputs=1, noutputs=1), + (name="test_leakyrelu_example", ninputs=1, noutputs=1), (name="test_selu", ninputs=1, noutputs=1), (name="test_selu_default", ninputs=1, noutputs=1), (name="test_selu_example", ninputs=1, noutputs=1), diff --git a/test/serialize/serialize.jl b/test/serialize/serialize.jl index 502f0f0..b171f30 100644 --- a/test/serialize/serialize.jl +++ b/test/serialize/serialize.jl @@ -42,6 +42,8 @@ @testset "Paramfree op $(tc.op) attrs: $(pairs(tc.attr))" for tc in ( (op=:Relu, attr = Dict(), fd=actfuns), + (op=:LeakyRelu, attr = Dict(), fd=actfuns), + (op=:LeakyRelu, attr = Dict(:alpha => 0.05f0), fd=actfuns), (op=:Elu, attr = Dict(), fd=actfuns), (op=:Elu, attr = Dict(:alpha => 0.5f0), fd=actfuns), (op=:Selu, attr = Dict(), fd=actfuns), @@ -154,6 +156,7 @@ @testset "Layer with activation function $actfun" for actfun in ( relu, + leakyrelu, elu, selu, tanh, From 1f2ea0d50920ebabc048750bf20a5fc1eee22897 Mon Sep 17 00:00:00 2001 From: Isabelle Tan Date: Sun, 9 Jul 2023 13:29:21 +0200 Subject: [PATCH 2/2] update test artifacts --- test/deserialize/Artifacts.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/deserialize/Artifacts.toml b/test/deserialize/Artifacts.toml index 3e5ed0b..94aa737 100644 --- a/test/deserialize/Artifacts.toml +++ b/test/deserialize/Artifacts.toml @@ -83,7 +83,7 @@ git-tree-sha1 = "8be97aa969ebdbe7599798d511a7790eba0697f2" git-tree-sha1 = "41eca620fb09f7d90ec9b875a80388566baadada" [test_div] -git-tree-sha1 = "430fd9135e60076904f717971bd174b0f16e1c54" +git-tree-sha1 = "57dd66f7274aac0e2a462e49dadf5a551c4e5e80" [test_dropout_default] git-tree-sha1 = "70fe420142b8d29b708578e4b6f2929e6907cb4c" @@ -185,7 +185,7 @@ git-tree-sha1 = "c8b0d06dc9733222906bb6471c39a9c41270d149" git-tree-sha1 = "19fe9305067a4225c6dd76264bb342cf966546ae" [test_matmul_2d] -git-tree-sha1 = "481de0ea5b1fb4692f10920215ce701df1b7ba09" +git-tree-sha1 = "3008bfa77da6160c406f6dae414b19861fef9f13" [test_maxpool_1d_default] git-tree-sha1 = "9b0a2b97518eb68122276b242313529582e4be95" @@ -278,10 +278,10 @@ git-tree-sha1 = "44c37442a35def50d4e1230cdcff8a986899d18d" git-tree-sha1 = "dc4a6180985e796aca6997ae79137fee6f9c05e9" [test_sigmoid] -git-tree-sha1 = "2bb16571d0809d1e6216a1b6eb5b27d332fac4f0" +git-tree-sha1 = "6f0e41cd8b1498f3c60b3d00b9558bf311217bf8" [test_sigmoid_example] -git-tree-sha1 = "710a8b85b3a7e9e301af21f8f1c0e7b087536ba5" +git-tree-sha1 = "b8b836fd3cb97d2801777c03e98d6cd41bc17b2d" [test_softmax_axis_0] git-tree-sha1 = "4f090b0b0f540b5f133176e4cf8a118c24db1886"