Skip to content

Commit

Permalink
Update some files
Browse files Browse the repository at this point in the history
  • Loading branch information
itsubaki committed Oct 8, 2023
1 parent 2f65bee commit 8f82763
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
7 changes: 4 additions & 3 deletions function/mean_squared_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ type MeanSquaredErrorT struct {
func (f *MeanSquaredErrorT) Forward(x ...*variable.Variable) []*variable.Variable {
f.x0, f.x1 = x[0], x[1]

diff := matrix.Sub(x[0].Data, x[1].Data)
N := float64(len(diff))
diff := matrix.Sub(x[0].Data, x[1].Data) // x0 - x1
N := float64(len(diff)) //
y := (1.0 / N) * matrix.Sum(matrix.Mul(diff, diff)) // (1/N) * sum((x0 - x1)^2)

y := (1.0 / N) * matrix.Sum(matrix.Mul(diff, diff))
return []*variable.Variable{
variable.New(y),
}
Expand All @@ -32,6 +32,7 @@ func (f *MeanSquaredErrorT) Backward(gy ...*variable.Variable) []*variable.Varia

gx0 := MulC(2.0/N, Mul(gyb, diff)) // gy * (x0 - x1) * 2/N
gx1 := Neg(gx0) // -gx0

return []*variable.Variable{
gx0,
gx1,
Expand Down
5 changes: 2 additions & 3 deletions function/softmax_cross_entropy.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@ func (f *SoftmaxCrossEntropyT) Backward(gy ...*variable.Variable) []*variable.Va
xs := f.x.Shape()
N, C := xs[0], xs[1]

gyn := MulC(1.0/float64(N), gy[0]) // gy = gy / N
y := Softmax(f.x) // y = softmax(x)
t := variable.NewOf(onehot(vector.Int(f.t.Data[0]), C)...) // t = onehot(t, C)
gx := Mul(Sub(y, t), gyn) // y = (y - t) * gy / N
y := Softmax(f.x) // y = softmax(x)
gx := Mul(Sub(y, t), MulC(1.0/float64(N), gy[0])) // (y - t) * gy / N

return []*variable.Variable{
gx,
Expand Down

0 comments on commit 8f82763

Please sign in to comment.