Skip to content

Commit

Permalink
Update some files
Browse files Browse the repository at this point in the history
  • Loading branch information
itsubaki committed Sep 30, 2023
1 parent 377bc41 commit 2792ece
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 42 deletions.
20 changes: 10 additions & 10 deletions layer/gru.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (l *GRU) Forward(x, h matrix.Matrix, _ ...Opts) matrix.Matrix {
l.hhat = matrix.F(matrix.Dot(x, Wxh).Add(matrix.Dot(h.Mul(l.r), Whh)).Add(Bh), activation.Tanh) // hhat = tanh(x.Wxh + (h * r).Whh + bh)
l.x, l.hprev = x, h

hnext := matrix.F(l.z, oneSub).Mul(l.hprev).Add(l.z.Mul(l.hhat)) // (1 - z) * hprev + z * hhat
hnext := matrix.SubC(1, l.z).Mul(l.hprev).Add(l.z.Mul(l.hhat)) // (1 - z) * hprev + z * hhat
return hnext
}

Expand All @@ -52,17 +52,17 @@ func (l *GRU) Backward(dhnext matrix.Matrix) (matrix.Matrix, matrix.Matrix) {
Whz, Whr, Whh := WhH[0], WhH[1], WhH[2]

// dh
dhhat := dhnext.Mul(l.z) // dhhat = dhnext * z
dhprev := dhnext.Mul(matrix.F(l.z, oneSub)) // dhprev = dhnext * (1 - z)
dhhat := dhnext.Mul(l.z) // dhhat = dhnext * z
dhprev := dhnext.Mul(matrix.SubC(1, l.z)) // dhprev = dhnext * (1 - z)

// tanh
dt := dhhat.Mul(matrix.F(l.hhat, oneSubPow2)) // dt = dhhat * (1 - hhat**2)
dbh := matrix.New(dt.SumAxis0()) // dbh = sum(dt, axis=0)
dWhh := matrix.Dot(l.r.Mul(l.hprev).T(), dt) // dWhh = (r * hprev).T.dt
dhr := matrix.Dot(dt, Whh.T()) // dhr = dt.Whh.T
dWxh := matrix.Dot(l.x.T(), dt) // dWxh = x.T.dt
dx := matrix.Dot(dt, Wxh.T()) // dx = dt.Wxh.T
dhprev = dhprev.Add(dhr.Mul(l.r)) // dhprev = dhprev + dhr * r
dt := dhhat.Mul(matrix.F(l.hhat, dtanh)) // dt = dhhat * (1 - hhat**2)
dbh := matrix.New(dt.SumAxis0()) // dbh = sum(dt, axis=0)
dWhh := matrix.Dot(l.r.Mul(l.hprev).T(), dt) // dWhh = (r * hprev).T.dt
dhr := matrix.Dot(dt, Whh.T()) // dhr = dt.Whh.T
dWxh := matrix.Dot(l.x.T(), dt) // dWxh = x.T.dt
dx := matrix.Dot(dt, Wxh.T()) // dx = dt.Wxh.T
dhprev = dhprev.Add(dhr.Mul(l.r)) // dhprev = dhprev + dhr * r

// gate(z)
dz := dhnext.Mul(l.hhat).Sub(dhnext.Mul(l.hprev)) // dz = dhnext * hhat - dhnext * hprev
Expand Down
10 changes: 5 additions & 5 deletions layer/lstm.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,18 @@ func (l *LSTM) Forward(x, h, c matrix.Matrix, _ ...Opts) (matrix.Matrix, matrix.

func (l *LSTM) Backward(dhNext, dcNext matrix.Matrix) (matrix.Matrix, matrix.Matrix, matrix.Matrix) {
tanh := matrix.F(l.cNext, activation.Tanh) // tanh(cNext)
dt := matrix.F(tanh, oneSubPow2) // 1 - tanh(cNext)**2
dt := matrix.F(tanh, dtanh) // 1 - tanh(cNext)**2
ds := dcNext.Add(dhNext.Mul(l.o).Mul(dt)) // dcNext + (dhNext * o) * (1 - tanh(cNext)**2)

df := ds.Mul(l.c) // ds * cPrev
dg := ds.Mul(l.i) // ds * i
di := ds.Mul(l.g) // ds * g
do := dhNext.Mul(tanh) // dhNext * tanh(cNext)

df = df.Mul(matrix.F(l.f, dSigmoid)) // df = df * f * (1 - f)
dg = dg.Mul(matrix.F(l.g, oneSubPow2)) // dg = dg * (1 - g**2)
di = di.Mul(matrix.F(l.i, dSigmoid)) // di = di * i * (1 - i)
do = do.Mul(matrix.F(l.o, dSigmoid)) // do = do * o * (1 - o)
df = df.Mul(matrix.F(l.f, dSigmoid)) // df = df * f * (1 - f)
dg = dg.Mul(matrix.F(l.g, dtanh)) // dg = dg * (1 - g**2)
di = di.Mul(matrix.F(l.i, dSigmoid)) // di = di * i * (1 - i)
do = do.Mul(matrix.F(l.o, dSigmoid)) // do = do * o * (1 - o)

dA := matrix.HStack(df, dg, di, do) // (N, 4H)

Expand Down
10 changes: 5 additions & 5 deletions layer/rnn.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ func (l *RNN) Forward(x, h matrix.Matrix, _ ...Opts) matrix.Matrix {
}

func (l *RNN) Backward(dhNext matrix.Matrix) (matrix.Matrix, matrix.Matrix) {
dt := dhNext.Mul(matrix.F(l.hNext, oneSubPow2)) // dt = dhNext * (1 - hNext**2)
dx := matrix.Dot(dt, l.Wx.T()) // dot(dt(N, H), Wx.T(H, D)) -> dx(N, D)
dh := matrix.Dot(dt, l.Wh.T()) // dot(dt(N, H), Wh.T(H, H)) -> dh(N, H)
dt := dhNext.Mul(matrix.F(l.hNext, dtanh)) // dt = dhNext * (1 - hNext**2)
dx := matrix.Dot(dt, l.Wx.T()) // dot(dt(N, H), Wx.T(H, D)) -> dx(N, D)
dh := matrix.Dot(dt, l.Wh.T()) // dot(dt(N, H), Wh.T(H, H)) -> dh(N, H)

l.DWx = matrix.Dot(l.x.T(), dt) // dot(x.T(D, N), dt(N, H)) -> (D, H)
l.DWh = matrix.Dot(l.h.T(), dt) // dot(hPrev.T(H, N), dt(N, H)) -> (H, H)
l.DB = matrix.New(dt.SumAxis0()) // sum(dt(N, H), axis=0) -> (1, H)
return dx, dh
}

// oneSubPow2 returns 1 - a**2
func oneSubPow2(a float64) float64 { return 1 - a*a }
// dtanh returns 1 - a**2
func dtanh(y float64) float64 { return 1 - y*y }
4 changes: 1 addition & 3 deletions layer/sigmoid_with_loss.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ func (l *SigmoidWithLoss) Forward(x, t matrix.Matrix, _ ...Opts) matrix.Matrix {
l.y, l.t = matrix.F(x, activation.Sigmoid), t

// loss = Loss(y, t) + Loss(1 - y, 1 - t)
loss := Loss(l.y, l.t) + Loss(matrix.F(l.y, oneSub), matrix.F(l.t, oneSub))
loss := Loss(l.y, l.t) + Loss(matrix.SubC(1, l.y), matrix.SubC(1, l.t))
return matrix.New([]float64{loss})
}

func (l *SigmoidWithLoss) Backward(dout matrix.Matrix) (matrix.Matrix, matrix.Matrix) {
dx := l.y.Sub(l.t).Mul(dout).MulC(1.0 / float64(len(l.t))) // (y - t) * dout / size
return dx, nil
}

func oneSub(v float64) float64 { return 1 - v }
5 changes: 5 additions & 0 deletions math/matrix/matrix.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,11 @@ func (m Matrix) Broadcast(a, b int) Matrix {
return m
}

// SubC returns c - m
func SubC(c float64, m Matrix) Matrix {
return F(m, func(v float64) float64 { return c - v })
}

// Dot returns the dot product of m and n.
func Dot(m, n Matrix) Matrix {
a, b := m.Dim()
Expand Down
34 changes: 15 additions & 19 deletions math/matrix/matrix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ func ExampleZero() {
// Output:
// [0 0 0]
// [0 0 0]

}

func ExampleOne() {
Expand All @@ -26,7 +25,6 @@ func ExampleOne() {
// Output:
// [1 1 1]
// [1 1 1]

}

func ExampleRand() {
Expand All @@ -41,8 +39,8 @@ func ExampleRand() {
// 2 3
// [0.6046602879796196 0.9405090880450124 0.6645600532184904]
// [0.4377141871869802 0.4246374970712657 0.6868230728671094]

}

func ExampleRandn() {
fmt.Println(matrix.Randn(2, 3).Dim())

Expand All @@ -55,7 +53,6 @@ func ExampleRandn() {
// 2 3
// [-1.233758177597947 -0.12634751070237293 -0.5209945711531503]
// [2.28571911769958 0.3228052526115799 0.5900672875996937]

}

func ExampleMask() {
Expand Down Expand Up @@ -165,7 +162,6 @@ func ExampleDot() {
// Output:
// [19 22]
// [43 50]

}

func ExampleMatrix_Dim() {
Expand All @@ -175,7 +171,6 @@ func ExampleMatrix_Dim() {
// Output:
// 0 0
// 1 3

}

func ExampleMatrix_Size() {
Expand All @@ -185,7 +180,6 @@ func ExampleMatrix_Size() {
// Output:
// 0
// 6

}

func ExampleMatrix_Add() {
Expand All @@ -206,7 +200,6 @@ func ExampleMatrix_Add() {
// Output:
// [6 8]
// [10 12]

}

func ExampleMatrix_Sub() {
Expand All @@ -227,7 +220,21 @@ func ExampleMatrix_Sub() {
// Output:
// [-4 -4]
// [-4 -4]
}

func ExampleSubC() {
A := matrix.New(
[]float64{1, 2},
[]float64{3, 4},
)

for _, r := range matrix.SubC(1, A) {
fmt.Println(r)
}

// Output:
// [0 -1]
// [-2 -3]
}

func ExampleMatrix_Mul() {
Expand All @@ -248,7 +255,6 @@ func ExampleMatrix_Mul() {
// Output:
// [5 12]
// [21 32]

}

func ExampleMatrix_Div() {
Expand All @@ -269,7 +275,6 @@ func ExampleMatrix_Div() {
// Output:
// [0.2 1]
// [3 0.5]

}

func ExampleMatrix_AddC() {
Expand All @@ -285,7 +290,6 @@ func ExampleMatrix_AddC() {
// Output:
// [3 4]
// [5 6]

}

func ExampleMatrix_MulC() {
Expand All @@ -301,7 +305,6 @@ func ExampleMatrix_MulC() {
// Output:
// [2 4]
// [6 8]

}

func ExampleMatrix_Pow2() {
Expand All @@ -317,7 +320,6 @@ func ExampleMatrix_Pow2() {
// Output:
// [1 4]
// [9 16]

}

func ExampleMatrix_Sqrt() {
Expand All @@ -333,7 +335,6 @@ func ExampleMatrix_Sqrt() {
// Output:
// [1 1.4142135623730951]
// [1.7320508075688772 2]

}

func ExampleMatrix_Abs() {
Expand All @@ -349,7 +350,6 @@ func ExampleMatrix_Abs() {
// Output:
// [1 2]
// [3 4]

}

func ExampleMatrix_Mean() {
Expand Down Expand Up @@ -557,7 +557,6 @@ func ExampleF2() {
// Output:
// [5 12]
// [21 32]

}

func ExampleF3() {
Expand All @@ -583,7 +582,6 @@ func ExampleF3() {
// Output:
// [-4 3]
// [12 23]

}

func ExamplePadding() {
Expand Down Expand Up @@ -611,7 +609,6 @@ func ExamplePadding() {
//
// [1 2]
// [3 4]

}

func ExampleReshape() {
Expand Down Expand Up @@ -677,5 +674,4 @@ func ExampleHStack() {
// Output:
// [1 2 3 7 8 9]
// [4 5 6 10 11 12]

}

0 comments on commit 2792ece

Please sign in to comment.