Skip to content

Commit

Permalink
Update some files
Browse files Browse the repository at this point in the history
  • Loading branch information
itsubaki committed Oct 16, 2023
1 parent 1fc5f69 commit 4144e34
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
6 changes: 3 additions & 3 deletions autograd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,11 @@ func Example_mlp() {
}

x := variable.Rand(100, 1, s)
y := variable.Rand(100, 1, s)
t := variable.Rand(100, 1, s)

for i := 0; i < 100; i++ {
yPred := m.Forward(x)
loss := F.MeanSquaredError(y, yPred)
y := m.Forward(x)
loss := F.MeanSquaredError(y, t)

m.Cleargrads()
loss.Backward()
Expand Down
14 changes: 5 additions & 9 deletions model/mlp.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import (
"github.com/itsubaki/autograd/variable"
)

type MLP struct {
type MLPOpts struct {
Activation Activation
Model
Source rand.Source
}

type MLPOpts struct {
type MLP struct {
Activation Activation
Source rand.Source
*Model
}

func NewMLP(outSize []int, opts ...MLPOpts) *MLP {
Expand All @@ -37,11 +37,7 @@ func NewMLP(outSize []int, opts ...MLPOpts) *MLP {

return &MLP{
Activation: activation,
Model: Model{
Layer: L.Layer{
Layers: layers,
},
},
Model: NewModel(layers),
}
}

Expand Down
8 changes: 8 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ type Model struct {
L.Layer
}

func NewModel(layers []*L.Layer) *Model {
return &Model{
Layer: L.Layer{
Layers: layers,
},
}
}

func (m Model) graph(y *variable.Variable, opt ...dot.Opt) []string {
out := make([]string, 0)
for _, txt := range dot.Graph(y, opt...) {
Expand Down

0 comments on commit 4144e34

Please sign in to comment.