Skip to content

Commit

Permalink
Update some files
Browse files Browse the repository at this point in the history
  • Loading branch information
itsubaki committed Oct 11, 2023
1 parent c6fa891 commit 858c51c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 12 deletions.
3 changes: 1 addition & 2 deletions function/softmax.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ func (f *SoftmaxT) Forward(x ...*variable.Variable) []*variable.Variable {
func (f *SoftmaxT) Backward(gy ...*variable.Variable) []*variable.Variable {
gyy := Mul(gy[0], f.y) // gyy = gy * y
sum := SumTo(len(gyy.Data), 1)(gyy) // sum = sum(gx, axis=1)
gx := Sub(gyy, Mul(f.y, sum)) // gx = gyy - y * sum

return []*variable.Variable{
gx,
Sub(gyy, Mul(f.y, sum)), // gyy - y * sum
}
}
12 changes: 6 additions & 6 deletions function/softmax_cross_entropy.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ func (f *SoftmaxCrossEntropyT) Forward(x ...*variable.Variable) []*variable.Vari
}

func (f *SoftmaxCrossEntropyT) Backward(gy ...*variable.Variable) []*variable.Variable {
xs := variable.Shape(f.x)
N, C := xs[0], xs[1]
shape := variable.Shape(f.x)
N, C := shape[0], shape[1]

t := variable.NewOf(onehot(f.t.Data[0], C)...) // t = onehot(t, C)
y := Softmax(f.x) // y = softmax(x)

t := variable.NewOf(onehot(f.t.Data[0], C)...) // t = onehot(t, C)
y := Softmax(f.x) // y = softmax(x)
gx := Mul(Sub(y, t), MulC(1.0/float64(N), gy[0])) // (y - t) * gy / N
return []*variable.Variable{
gx,
Mul(Sub(y, t), MulC(1.0/float64(N), gy[0])), // (y - t) * gy / N
}
}

Expand Down
5 changes: 1 addition & 4 deletions variable/max.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@ func (f *MaxT) Forward(x ...*Variable) []*Variable {

func (f *MaxT) Backward(gy ...*Variable) []*Variable {
mask := NewOf(matrix.F2(f.x.Data, f.y.Data, cond)...)
gybr := BroadcastTo(Shape(mask)...)(gy[0])

gx := Mul(gybr, mask)
return []*Variable{
gx,
Mul(gy[0], mask),
}
}

Expand Down

0 comments on commit 858c51c

Please sign in to comment.