Skip to content

Commit

Permalink
Added Tests for K.Multiply() with VarVectorTranspose
Browse files Browse the repository at this point in the history
  • Loading branch information
kwesiRutledge committed Nov 4, 2023
1 parent b1224a4 commit b7533d8
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 2 deletions.
7 changes: 5 additions & 2 deletions optim/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,14 @@ func (c K) Multiply(term1 interface{}, errors ...error) (Expression, error) {
case VarVectorTranspose:
if right.Len() == 1 {
rightTransposed := right.Transpose().(VarVector)
return ScalarLinearExpr{
prod := ScalarLinearExpr{
L: OnesVector(1),
X: rightTransposed.Copy(),
C: 0.0,
}, nil
}
prod.L.ScaleVec(float64(c), &prod.L)

return prod, nil
} else {
var vleOut VectorLinearExpressionTranspose
vleOut.X = right.Copy().Transpose().(VarVector)
Expand Down
99 changes: 99 additions & 0 deletions testing/optim/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,105 @@ func TestK_Multiply11(t *testing.T) {
}
}

/*
TestK_Multiply12
Description:
Tests the multiplication of a constant with a
non-unit VarVectorTranspose.
*/
func TestK_Multiply12(t *testing.T) {
// Constants
m := optim.NewModel("TestK_Multiply11")
k1 := optim.K(3.14)
vv2 := m.AddVariableVector(21)
vvt2 := vv2.Transpose()

// Check Multiplication result
prod, err := k1.Multiply(vvt2)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

prodAsVLET, tf := prod.(optim.VectorLinearExpressionTranspose)
if !tf {
t.Errorf(
"Expected product to be of type %v received type %T",
"VectorLinearExpressionTranspose",
prod,
)
}

// Check all elements of the matrix
for rowIndex := 0; rowIndex < vvt2.Len(); rowIndex++ {
for colIndex := 0; colIndex < vvt2.Len(); colIndex++ {
if (prodAsVLET.L.At(rowIndex, colIndex) != 1.0*float64(k1)) && (rowIndex == colIndex) {
t.Errorf(
" L[%v,%v] = %v =/= %v",
rowIndex, colIndex,
prodAsVLET.L.At(rowIndex, colIndex),
1.0*float64(k1),
)
}

if (prodAsVLET.L.At(rowIndex, colIndex) != 0.0) && (rowIndex != colIndex) {
t.Errorf(
"L[%v,%v] = %v =/= 0.0",
rowIndex, colIndex,
prodAsVLET.L.At(rowIndex, colIndex),
)
}
}

}
}

/*
TestK_Multiply13
Description:
Tests the multiplication of a constant with a
non-unit VarVectorTranspose.
*/
func TestK_Multiply13(t *testing.T) {
// Constants
m := optim.NewModel("TestK_Multiply13")
k1 := optim.K(3.14)
vv2 := m.AddVariableVector(1)
vvt2 := vv2.Transpose()

// Check Multiplication result
prod, err := k1.Multiply(vvt2)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

prodAsSLE, tf := prod.(optim.ScalarLinearExpr)
if !tf {
t.Errorf(
"Expected product to be of type %v received type %T",
"ScalarLinearExpression",
prod,
)
}

// Check all elements of the matrix
if prodAsSLE.L.AtVec(0) != 1.0*float64(k1) {
t.Errorf(
"L[0] = %v =/= %v",
prodAsSLE.L.AtVec(0),
1.0*float64(k1),
)
}

if prodAsSLE.C != 0.0 {
t.Errorf(
"C = %v =/= 0.0",
prodAsSLE.C,
)
}
}

/*
TestK_Check1
Description:
Expand Down

0 comments on commit b7533d8

Please sign in to comment.