Skip to content

Commit

Permalink
Add variable.Shape
Browse files Browse the repository at this point in the history
  • Loading branch information
itsubaki committed Oct 10, 2023
1 parent 1759636 commit d7121a3
Show file tree
Hide file tree
Showing 17 changed files with 26 additions and 18 deletions.
2 changes: 2 additions & 0 deletions function/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ var (
Reshape = variable.Reshape
Transpose = variable.Transpose
MatMul = variable.MatMul
Max = variable.Max
Min = variable.Min
Clip = variable.Clip
GetItem = variable.GetItem
GetItemGrad = variable.GetItemGrad
Expand Down
2 changes: 1 addition & 1 deletion function/linear.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,6 @@ func (f *LinearT) Backward(gy ...*variable.Variable) []*variable.Variable {
}

// add bias
gb := SumTo(f.b.Shape()...)(gy[0])
gb := SumTo(variable.Shape(f.b)...)(gy[0])
return append(gxs, gb)
}
2 changes: 1 addition & 1 deletion function/mean_squared_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (f *MeanSquaredErrorT) Forward(x ...*variable.Variable) []*variable.Variabl
func (f *MeanSquaredErrorT) Backward(gy ...*variable.Variable) []*variable.Variable {
diff := Sub(f.x0, f.x1)
N := float64(len(diff.Data))
gyb := BroadcastTo(diff.Shape()...)(gy[0])
gyb := BroadcastTo(variable.Shape(diff)...)(gy[0])

gx0 := MulC(2.0/N, Mul(gyb, diff)) // gy * (x0 - x1) * 2/N
gx1 := Neg(gx0) // -gx0
Expand Down
2 changes: 1 addition & 1 deletion function/simple_dropout.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func Dropout(ratio float64, s ...rand.Source) func(x ...*variable.Variable) *var
return x[0]
}

xs := x[0].Shape()
xs := variable.Shape(x[0])
mask := matrix.Mask(matrix.Rand(xs[0], xs[1], s...), mask(ratio))
return MulC(1.0/(1.0-ratio), Mul(x[0], variable.NewOf(mask...))) // y = x * mask / (1 - ratio)
}
Expand Down
2 changes: 1 addition & 1 deletion function/softmax_cross_entropy.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (f *SoftmaxCrossEntropyT) Forward(x ...*variable.Variable) []*variable.Vari
}

func (f *SoftmaxCrossEntropyT) Backward(gy ...*variable.Variable) []*variable.Variable {
xs := f.x.Shape()
xs := variable.Shape(f.x)
N, C := xs[0], xs[1]

y := Softmax(f.x) // y = softmax(x)
Expand Down
2 changes: 2 additions & 0 deletions numerical/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ var (
_ Func = F.Reshape(2, 2)
_ Func = F.Transpose
_ Func = F.MatMul
_ Func = F.Max
_ Func = F.Min
_ Func = F.Clip(0.0, 1.0)
_ Func = F.Linear
_ Func = F.Sigmoid
Expand Down
2 changes: 1 addition & 1 deletion variable/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type AddT struct {
}

func (f *AddT) Forward(x ...*Variable) []*Variable {
f.x0Shape, f.x1Shape = x[0].Shape(), x[1].Shape()
f.x0Shape, f.x1Shape = Shape(x[0]), Shape(x[1])

x0, x1 := matrix.Broadcast(x[0].Data, x[1].Data)
y := matrix.Add(x0, x1)
Expand Down
2 changes: 1 addition & 1 deletion variable/broadcast_to.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type BroadcastToT struct {
}

func (f *BroadcastToT) Forward(x ...*Variable) []*Variable {
f.xShape = x[0].Shape()
f.xShape = Shape(x[0])

y := matrix.BroadcastTo(f.Shape, x[0].Data)
return []*Variable{
Expand Down
2 changes: 1 addition & 1 deletion variable/div.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type DivT struct {
}

func (f *DivT) Forward(x ...*Variable) []*Variable {
f.x0Shape, f.x1Shape = x[0].Shape(), x[1].Shape()
f.x0Shape, f.x1Shape = Shape(x[0]), Shape(x[1])
f.x0, f.x1 = x[0], x[1]

x0, x1 := matrix.Broadcast(x[0].Data, x[1].Data)
Expand Down
2 changes: 1 addition & 1 deletion variable/get_item.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ type GetItemT struct {
}

func (f *GetItemT) Forward(x ...*Variable) []*Variable {
f.xShape = x[0].Shape()
f.xShape = Shape(x[0])

y := make([][]float64, len(f.Slices))
for i, idx := range f.Slices {
Expand Down
12 changes: 8 additions & 4 deletions variable/max.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,20 @@ func (f *MaxT) Forward(x ...*Variable) []*Variable {
}

func (f *MaxT) Backward(gy ...*Variable) []*Variable {
ybr := matrix.BroadcastTo(f.x.Shape(), f.y.Data)
cond := NewOf(matrix.F2(f.x.Data, ybr, cond)...)
gybr := BroadcastTo(cond.Shape()...)(gy[0])
ybr := matrix.BroadcastTo(Shape(f.x), f.y.Data)
mask := mask(f.x.Data, ybr)
gybr := BroadcastTo(Shape(mask)...)(gy[0])

gx := Mul(gybr, cond)
gx := Mul(gybr, mask)
return []*Variable{
gx,
}
}

func mask(x, y [][]float64) *Variable {
return NewOf(matrix.F2(x, y, cond)...)
}

func cond(a, b float64) float64 {
if math.Abs(a-b) < 1e-13 {
return 1.0
Expand Down
2 changes: 1 addition & 1 deletion variable/mul.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type MulT struct {
}

func (f *MulT) Forward(x ...*Variable) []*Variable {
f.x0Shape, f.x1Shape = x[0].Shape(), x[1].Shape()
f.x0Shape, f.x1Shape = Shape(x[0]), Shape(x[1])
f.x0, f.x1 = x[0], x[1]

x0, x1 := matrix.Broadcast(x[0].Data, x[1].Data)
Expand Down
2 changes: 1 addition & 1 deletion variable/reshape.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type ReshapeT struct {
}

func (f *ReshapeT) Forward(x ...*Variable) []*Variable {
f.xShape = x[0].Shape()
f.xShape = Shape(x[0])

y := matrix.Reshape(f.Shape, x[0].Data)
return []*Variable{
Expand Down
2 changes: 1 addition & 1 deletion variable/sub.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type SubT struct {
}

func (f *SubT) Forward(x ...*Variable) []*Variable {
f.x0Shape, f.x1Shape = x[0].Shape(), x[1].Shape()
f.x0Shape, f.x1Shape = Shape(x[0]), Shape(x[1])

x0, x1 := matrix.Broadcast(x[0].Data, x[1].Data)
y := matrix.Sub(x0, x1)
Expand Down
2 changes: 1 addition & 1 deletion variable/sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type SumT struct {
}

func (f *SumT) Forward(x ...*Variable) []*Variable {
f.xShape = x[0].Shape()
f.xShape = Shape(x[0])

y := matrix.Sum(x[0].Data)
return []*Variable{
Expand Down
2 changes: 1 addition & 1 deletion variable/sum_to.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type SumToT struct {
}

func (f *SumToT) Forward(x ...*Variable) []*Variable {
f.xShape = x[0].Shape()
f.xShape = Shape(x[0])

y := matrix.SumTo(f.Shape, x[0].Data)
return []*Variable{
Expand Down
2 changes: 1 addition & 1 deletion variable/variable.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func Randn(m, n int, s ...rand.Source) *Variable {
return &Variable{Data: matrix.Randn(m, n, s...)}
}

func (v *Variable) Shape() []int {
func Shape(v *Variable) []int {
return matrix.Shape(v.Data)
}

Expand Down

0 comments on commit d7121a3

Please sign in to comment.