Skip to content

Commit

Permalink
Add LSTM model
Browse files — browse the repository at this point in the history
  • Loading branch information
itsubaki committed Oct 21, 2023
1 parent 0f2061b commit 34f3ee4
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 16 deletions.
File renamed without changes.
File renamed without changes.
44 changes: 44 additions & 0 deletions model/lstm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package model

import (
"math/rand"

L "github.com/itsubaki/autograd/layer"
"github.com/itsubaki/autograd/variable"
)

// LSTMOpts holds optional configuration for NewLSTM.
type LSTMOpts struct {
	// Source seeds the random initialization of layer weights.
	// A nil Source leaves the layers to use their default source.
	Source rand.Source
}

// LSTM is a model composed of an LSTM layer followed by a Linear layer
// (see NewLSTM). It embeds Model for parameter management.
type LSTM struct {
	Model
}

// NewLSTM returns a model with an LSTM layer of hiddenSize units
// followed by a Linear layer of outSize units. An optional LSTMOpts
// may supply a rand.Source used to initialize both layers' weights;
// when omitted (or nil) the layers fall back to their default source.
func NewLSTM(hiddenSize, outSize int, opts ...LSTMOpts) *LSTM {
	var source rand.Source
	if len(opts) > 0 {
		if src := opts[0].Source; src != nil {
			source = src
		}
	}

	layers := []L.Layer{
		L.LSTM(hiddenSize, L.LSTMOpts{Source: source}),
		L.Linear(outSize, L.LinearOpts{Source: source}),
	}

	return &LSTM{Model: Model{Layers: layers}}
}

// ResetState clears the internal state of the LSTM layer so that the
// next Forward call starts a fresh sequence.
// The type assertion assumes Layers[0] is the *L.LSTMT installed by
// NewLSTM; it panics if the model was constructed differently.
func (m *LSTM) ResetState() {
	m.Layers[0].(*L.LSTMT).ResetState()
}

// Forward feeds x through each layer in order and returns the output
// of the final layer.
func (m *LSTM) Forward(x *variable.Variable) *variable.Variable {
	out := x
	for i := range m.Layers {
		out = m.Layers[i].First(out)
	}

	return out
}
87 changes: 87 additions & 0 deletions model/lstm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package model_test

import (
"fmt"
"math/rand"

"github.com/itsubaki/autograd/model"
"github.com/itsubaki/autograd/variable"
)

// ExampleLSTM shows the layer composition of a freshly built model.
func ExampleLSTM() {
	m := model.NewLSTM(2, 3)

	for i := range m.Layers {
		fmt.Printf("%T\n", m.Layers[i])
	}

	// Output:
	// *layer.LSTMT
	// *layer.LinearT
}

// ExampleLSTM_backward runs two forward/backward passes with a fixed
// random seed and prints every parameter gradient, pinning the
// accumulated values. The output is unordered because Params()
// iteration order is not fixed.
func ExampleLSTM_backward() {
	// Seeded source makes the initial weights — and therefore the
	// gradients below — deterministic.
	m := model.NewLSTM(1, 1, model.LSTMOpts{
		Source: rand.NewSource(1),
	})

	x := variable.New(1, 2)
	// Two passes: the second Backward accumulates onto the first pass's
	// gradients (no Cleargrads in between), and the recurrent state from
	// the first Forward feeds into the second.
	y := m.Forward(x)
	y.Backward()
	y = m.Forward(x)
	y.Backward()

	for _, l := range m.Layers {
		fmt.Printf("%T\n", l)
		for _, p := range l.Params() {
			fmt.Println(p.Name, p.Grad)
		}
	}

	// Unordered output:
	// *layer.LSTMT
	// w variable([0.011226806999443534])
	// w variable([0.11294101620706003])
	// w variable([0.00036758770803267443])
	// w variable([0.013679446277302425])
	// b variable([0.03359993499751152])
	// b variable([0.016635625479843978])
	// b variable([0.09921510586976715])
	// b variable([0.4532075079866421])
	// w variable([[0.03359993499751152] [0.06719986999502305]])
	// w variable([[0.09921510586976715] [0.1984302117395343]])
	// w variable([[0.4532075079866421] [0.9064150159732842]])
	// w variable([[0.016635625479843978] [0.033271250959687956]])
	// *layer.LinearT
	// b variable([2])
	// w variable([0.8926291447755661])
}

// ExampleLSTM_ResetState checks that Forward alone (no Backward) leaves
// every parameter gradient nil, and that ResetState between the two
// Forward calls does not disturb the parameters themselves.
func ExampleLSTM_ResetState() {
	m := model.NewLSTM(1, 1)

	x := variable.New(1, 2)
	m.Forward(x)
	m.ResetState()
	m.Forward(x)

	// No Backward was called, so every gradient is still nil.
	for _, p := range m.Params() {
		fmt.Println(p.Name, p.Grad)
	}

	// Unordered output:
	// w <nil>
	// b <nil>
	// w <nil>
	// b <nil>
	// w <nil>
	// b <nil>
	// w <nil>
	// b <nil>
	// w <nil>
	// b <nil>
	// w <nil>
	// w <nil>
	// w <nil>
	// w <nil>
}
17 changes: 1 addition & 16 deletions model/mlp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,6 @@ func ExampleMLP() {
// *layer.LinearT
}

// ExampleMLPOpts builds a seeded two-layer MLP with ReLU activation and
// pins the forward output for a fixed input.
func ExampleMLPOpts() {
	m := model.NewMLP([]int{5, 1}, model.MLPOpts{
		Activation: F.ReLU,
		Source:     rand.NewSource(1),
	})

	x := variable.New(1, 2)
	y := m.Forward(x)
	fmt.Println(y)

	// Output:
	// variable([1.179297448305554])
}

func ExampleMLP_backward() {
m := model.NewMLP([]int{5, 1}, model.MLPOpts{
Activation: F.ReLU,
Expand All @@ -44,8 +30,8 @@ func ExampleMLP_backward() {

x := variable.New(1, 2)
y := m.Forward(x)

y.Backward()

for _, p := range m.Params() {
fmt.Println(p.Name, p.Grad)
}
Expand All @@ -65,7 +51,6 @@ func ExampleMLP_cleargrads() {

x := variable.New(1, 2)
y := m.Forward(x)

y.Backward()
m.Cleargrads()

Expand Down
1 change: 1 addition & 0 deletions optimizer/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

var (
_ Model = (*model.MLP)(nil)
_ Model = (*model.LSTM)(nil)
)

var (
Expand Down

0 comments on commit 34f3ee4

Please sign in to comment.