From b7533d82b54dd9828b8f0639758ba8c3b03c4a00 Mon Sep 17 00:00:00 2001 From: Kwesi Rutledge Date: Sat, 4 Nov 2023 14:23:00 -0400 Subject: [PATCH] Added Tests for K.Multiply() with VarVectorTranspose --- optim/constant.go | 7 ++- testing/optim/constant_test.go | 99 ++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 2 deletions(-) diff --git a/optim/constant.go b/optim/constant.go index 64a1abe..7885fce 100644 --- a/optim/constant.go +++ b/optim/constant.go @@ -193,11 +193,14 @@ func (c K) Multiply(term1 interface{}, errors ...error) (Expression, error) { case VarVectorTranspose: if right.Len() == 1 { rightTransposed := right.Transpose().(VarVector) - return ScalarLinearExpr{ + prod := ScalarLinearExpr{ L: OnesVector(1), X: rightTransposed.Copy(), C: 0.0, - }, nil + } + prod.L.ScaleVec(float64(c), &prod.L) + + return prod, nil } else { var vleOut VectorLinearExpressionTranspose vleOut.X = right.Copy().Transpose().(VarVector) diff --git a/testing/optim/constant_test.go b/testing/optim/constant_test.go index 95f7107..3311d42 100644 --- a/testing/optim/constant_test.go +++ b/testing/optim/constant_test.go @@ -976,6 +976,105 @@ func TestK_Multiply11(t *testing.T) { } } +/* +TestK_Multiply12 +Description: + + Tests the multiplication of a constant with a + non-unit VarVectorTranspose. +*/ +func TestK_Multiply12(t *testing.T) { + // Constants + m := optim.NewModel("TestK_Multiply11") + k1 := optim.K(3.14) + vv2 := m.AddVariableVector(21) + vvt2 := vv2.Transpose() + + // Check Multiplication result + prod, err := k1.Multiply(vvt2) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + prodAsVLET, tf := prod.(optim.VectorLinearExpressionTranspose) + if !tf { + t.Errorf( + "Expected product to be of type %v received type %T", + "VectorLinearExpressionTranspose", + prod, + ) + } + + // Check all elements of the matrix + for rowIndex := 0; rowIndex < vvt2.Len(); rowIndex++ { + for colIndex := 0; colIndex < vvt2.Len(); colIndex++ { + if (prodAsVLET.L.At(rowIndex, colIndex) != 1.0*float64(k1)) && (rowIndex == colIndex) { + t.Errorf( + " L[%v,%v] = %v =/= %v", + rowIndex, colIndex, + prodAsVLET.L.At(rowIndex, colIndex), + 1.0*float64(k1), + ) + } + + if (prodAsVLET.L.At(rowIndex, colIndex) != 0.0) && (rowIndex != colIndex) { + t.Errorf( + "L[%v,%v] = %v =/= 0.0", + rowIndex, colIndex, + prodAsVLET.L.At(rowIndex, colIndex), + ) + } + } + + } +} + +/* +TestK_Multiply13 +Description: + + Tests the multiplication of a constant with a + non-unit VarVectorTranspose. +*/ +func TestK_Multiply13(t *testing.T) { + // Constants + m := optim.NewModel("TestK_Multiply13") + k1 := optim.K(3.14) + vv2 := m.AddVariableVector(1) + vvt2 := vv2.Transpose() + + // Check Multiplication result + prod, err := k1.Multiply(vvt2) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + prodAsSLE, tf := prod.(optim.ScalarLinearExpr) + if !tf { + t.Errorf( + "Expected product to be of type %v received type %T", + "ScalarLinearExpression", + prod, + ) + } + + // Check all elements of the matrix + if prodAsSLE.L.AtVec(0) != 1.0*float64(k1) { + t.Errorf( + "L[0] = %v =/= %v", + prodAsSLE.L.AtVec(0), + 1.0*float64(k1), + ) + } + + if prodAsSLE.C != 0.0 { + t.Errorf( + "C = %v =/= 0.0", + prodAsSLE.C, + ) + } +} + /* TestK_Check1 Description: