Skip to content

Commit

Permalink
Added More Tests for Constant.Multiply(); Added More Simplifying Code…
Browse files Browse the repository at this point in the history
… to Multiply
  • Loading branch information
kwesiRutledge committed Nov 3, 2023
1 parent 69ba0ff commit b1224a4
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 33 deletions.
78 changes: 45 additions & 33 deletions optim/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,73 +144,81 @@ func (c K) Multiply(term1 interface{}, errors ...error) (Expression, error) {
}

// Algorithm
switch term1Converted := term1.(type) {
switch right := term1.(type) {
case float64:
return c.Multiply(K(term1Converted))
return c.Multiply(K(right))
case K:
return c * term1Converted, nil
return c * right, nil
case Variable:
// Algorithm
term1AsSLE := term1Converted.ToScalarLinearExpression()
term1AsSLE := right.ToScalarLinearExpression()

return c.Multiply(term1AsSLE)
case ScalarLinearExpr:
// Scale all vectors and constants
sleOut := term1Converted.Copy()
sleOut := right.Copy()
sleOut.L.ScaleVec(float64(c), &sleOut.L)
sleOut.C = term1Converted.C * float64(c)
sleOut.C = right.C * float64(c)

return sleOut, nil
case ScalarQuadraticExpression:
// Scale all matrices and constants
var sqeOut ScalarQuadraticExpression
sqeOut.Q.Scale(float64(c), &term1Converted.Q)
sqeOut.L.ScaleVec(float64(c), &term1Converted.L)
sqeOut.C = float64(c) * term1Converted.C
sqeOut.Q.Scale(float64(c), &right.Q)
sqeOut.L.ScaleVec(float64(c), &right.L)
sqeOut.C = float64(c) * right.C

return sqeOut, nil
case KVector:
var prod mat.VecDense = ZerosVector(term1Converted.Len())
term1AsVecDense := mat.VecDense(term1Converted)
var prod mat.VecDense = ZerosVector(right.Len())
term1AsVecDense := mat.VecDense(right)

prod.ScaleVec(float64(c), &term1AsVecDense)

return KVector(prod), nil
case KVectorTranspose:
var prod mat.VecDense = ZerosVector(term1Converted.Len())
term1AsVecDense := mat.VecDense(term1Converted)
var prod mat.VecDense = ZerosVector(right.Len())
term1AsVecDense := mat.VecDense(right)

prod.ScaleVec(float64(c), &term1AsVecDense)

return KVectorTranspose(prod), nil
case VarVector:
var vleOut VectorLinearExpr
vleOut.X = term1Converted.Copy()
tempIdentity := Identity(term1Converted.Len()) // Is this needed?
vleOut.L.Scale(float64(c), &tempIdentity)
vleOut.C = ZerosVector(term1Converted.Len())

return vleOut, nil
// VarVector is of unit length.
return ScalarLinearExpr{
L: OnesVector(1),
X: right.Copy(),
C: 0.0,
}, nil
case VarVectorTranspose:
var vleOut VectorLinearExpressionTranspose
vleOut.X = term1Converted.Copy().Transpose().(VarVector)
tempIdentity := Identity(term1Converted.Len()) // Is this needed?
vleOut.L.Scale(float64(c), &tempIdentity)
vleOut.C = ZerosVector(term1Converted.Len())

return vleOut, nil
if right.Len() == 1 {
rightTransposed := right.Transpose().(VarVector)
return ScalarLinearExpr{
L: OnesVector(1),
X: rightTransposed.Copy(),
C: 0.0,
}, nil
} else {
var vleOut VectorLinearExpressionTranspose
vleOut.X = right.Copy().Transpose().(VarVector)
tempIdentity := Identity(right.Len()) // Is this needed?
vleOut.L.Scale(float64(c), &tempIdentity)
vleOut.C = ZerosVector(right.Len())

return vleOut, nil
}
case VectorLinearExpr:
var vleOut VectorLinearExpr
vleOut.L.Scale(float64(c), &term1Converted.L)
vleOut.C.ScaleVec(float64(c), &term1Converted.C)
vleOut.X = term1Converted.X.Copy()
vleOut.L.Scale(float64(c), &right.L)
vleOut.C.ScaleVec(float64(c), &right.C)
vleOut.X = right.X.Copy()

return vleOut, nil
case VectorLinearExpressionTranspose:
var vletOut VectorLinearExpressionTranspose
vletOut.L.Scale(float64(c), &term1Converted.L)
vletOut.C.ScaleVec(float64(c), &term1Converted.C)
vletOut.X = term1Converted.X.Copy()
vletOut.L.Scale(float64(c), &right.L)
vletOut.C.ScaleVec(float64(c), &right.C)
vletOut.X = right.X.Copy()

return vletOut, nil
default:
Expand All @@ -222,3 +230,7 @@ func (c K) Multiply(term1 interface{}, errors ...error) (Expression, error) {
func (c K) Dims() []int {
return []int{1, 1} // Signifies scalar
}

func (c K) Check() error {
return nil
}
80 changes: 80 additions & 0 deletions testing/optim/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -911,3 +911,83 @@ func TestK_Multiply9(t *testing.T) {
}
}
}

/*
TestK_Multiply10
Description:
Tests the ability to multiply a constant with a VarVector
of non-unit length.
*/
func TestK_Multiply10(t *testing.T) {
// Constants
m := optim.NewModel("TestK_Multiply10")
N := 3
c1 := optim.K(3.14)

kv2 := m.AddVariableVector(N)

// Algorithm
_, err := c1.Multiply(kv2)
if err == nil {
t.Errorf("no error was thrown, but there should have been!")
} else {
if !strings.Contains(
err.Error(),
optim.DimensionError{
Operation: "Multiply",
Arg1: c1,
Arg2: kv2,
}.Error(),
) {
t.Errorf("unexpected error: %v", err)
}
}
}

/*
TestK_Multiply11
Description:
Tests the ability to multiply a constant with a VarVector
of unit length.
*/
func TestK_Multiply11(t *testing.T) {
// Constants
m := optim.NewModel("TestK_Multiply10")
N := 1
c1 := optim.K(3.14)

kv2 := m.AddVariableVector(N)

// Algorithm
prod, err := c1.Multiply(kv2)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

prodAsSLE, tf := prod.(optim.ScalarLinearExpr)
if !tf {
t.Errorf("expected product to be of type ScalarLinearExpr; received %T", prod)
}

if prodAsSLE.C != 0.0 {
t.Errorf("prod.C = %v =/= 0.0", prodAsSLE.C)
}
}

/*
TestK_Check1
Description:
Tests that the Check() method returns nil as expected.
*/
func TestK_Check1(t *testing.T) {
// Constants
k1 := optim.K(3.14)

// Algorithm
if k1.Check() != nil {
t.Errorf("unexpected error: %v", k1.Check())
}
}

0 comments on commit b1224a4

Please sign in to comment.