Skip to content

Commit

Permalink
Update some files
Browse files Browse the repository at this point in the history
  • Loading branch information
itsubaki committed Oct 21, 2023
1 parent 15a45e9 commit 4c57bd1
Showing 1 changed file with 38 additions and 27 deletions.
65 changes: 38 additions & 27 deletions cmd/lstm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,37 @@ import (
"github.com/itsubaki/autograd/vector"
)

type SinCurve struct {
N int
type DataLoader struct {
BatchSize int
N int
Data []float64
Label []float64
iter int
}

func NewSinCurve(batchSize int) *SinCurve {
func (d *DataLoader) Next() bool {
next := (d.iter+1)*d.BatchSize < d.N
if !next {
d.iter = 0
}

return next
}

func (d *DataLoader) Batch() (*variable.Variable, *variable.Variable) {
begin, end := d.iter*d.BatchSize, (d.iter+1)*d.BatchSize
x, y := vector.Transpose(d.Data[begin:end]), vector.Transpose(d.Label[begin:end])
d.iter++
return variable.NewOf(x...), variable.NewOf(y...)
}

type SinCurve struct {
N int
Data []float64
Label []float64
}

func NewSinCurve() *SinCurve {
N, noise := 1000, 0.05

x := make([]float64, N)
Expand All @@ -35,28 +57,10 @@ func NewSinCurve(batchSize int) *SinCurve {
}

return &SinCurve{
N: N,
BatchSize: batchSize,
Data: y[:len(x)-1],
Label: y[1:],
iter: 0,
}
}

func (d *SinCurve) Next() bool {
next := (d.iter+1)*d.BatchSize < d.N
if !next {
d.iter = 0
N: N,
Data: y[:len(x)-1],
Label: y[1:],
}

return next
}

func (d *SinCurve) Read() (*variable.Variable, *variable.Variable) {
begin, end := d.iter*d.BatchSize, (d.iter+1)*d.BatchSize
x, y := vector.Transpose(d.Data[begin:end]), vector.Transpose(d.Label[begin:end])
d.iter++
return variable.NewOf(x...), variable.NewOf(y...)
}

func main() {
Expand All @@ -67,16 +71,23 @@ func main() {
flag.IntVar(&bpttLength, "bptt-length", 30, "")
flag.Parse()

dataset := NewSinCurve()
dataloader := &DataLoader{
BatchSize: batchSize,
N: dataset.N,
Data: dataset.Data,
Label: dataset.Label,
}

m := model.NewLSTM(hiddenSize, 1)
o := optimizer.SGD{LearningRate: 0.01}

dataset := NewSinCurve(batchSize)
for i := 0; i < epoch; i++ {
m.ResetState()

loss, count := variable.Const(0), 0
for dataset.Next() {
x, t := dataset.Read()
for dataloader.Next() {
x, t := dataloader.Batch()
y := m.Forward(x)
loss = F.Add(loss, F.MeanSquaredError(y, t))

Expand Down

0 comments on commit 4c57bd1

Please sign in to comment.