Skip to content

Commit

Permalink
ADd MLPOpts
Browse files Browse the repository at this point in the history
  • Loading branch information
itsubaki committed Oct 15, 2023
1 parent 3e841ef commit bc5eb37
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
18 changes: 17 additions & 1 deletion model/mlp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package model
import (
"math/rand"

F "github.com/itsubaki/autograd/function"
L "github.com/itsubaki/autograd/layer"
"github.com/itsubaki/autograd/variable"
)
Expand All @@ -12,7 +13,22 @@ type MLP struct {
Activation Activation
}

func NewMLP(outSize []int, activation Activation, s ...rand.Source) *MLP {
type MLPOpts struct {
Activation Activation
Source rand.Source
}

func NewMLP(outSize []int, opts ...MLPOpts) *MLP {
activation := F.Sigmoid
if len(opts) > 0 && opts[0].Activation != nil {
activation = opts[0].Activation
}

s := make([]rand.Source, 0)
if len(opts) > 0 && opts[0].Source != nil {
s = append(s, opts[0].Source)
}

layers := make([]*L.Layer, len(outSize))
for i := 0; i < len(outSize); i++ {
layers[i] = L.Linear(outSize[i], s...)
Expand Down
12 changes: 8 additions & 4 deletions model/mlp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (
)

func ExampleMLP() {
s := rand.NewSource(1)
mlp := model.NewMLP([]int{5, 1}, F.ReLU, s)
mlp := model.NewMLP([]int{5, 1}, model.MLPOpts{
Activation: F.ReLU,
Source: rand.NewSource(1),
})

x := variable.New(1, 2)
y := mlp.Forward(x)
Expand All @@ -22,8 +24,10 @@ func ExampleMLP() {
}

func ExampleMLP_backward() {
s := rand.NewSource(1)
mlp := model.NewMLP([]int{5, 1}, F.ReLU, s)
mlp := model.NewMLP([]int{5, 1}, model.MLPOpts{
Activation: F.ReLU,
Source: rand.NewSource(1),
})

x := variable.New(1, 2)
y := mlp.Forward(x)
Expand Down

0 comments on commit bc5eb37

Please sign in to comment.