diff --git a/optim/constant.go b/optim/constant.go index 37c9a63..64a1abe 100644 --- a/optim/constant.go +++ b/optim/constant.go @@ -144,73 +144,81 @@ func (c K) Multiply(term1 interface{}, errors ...error) (Expression, error) { } // Algorithm - switch term1Converted := term1.(type) { + switch right := term1.(type) { case float64: - return c.Multiply(K(term1Converted)) + return c.Multiply(K(right)) case K: - return c * term1Converted, nil + return c * right, nil case Variable: // Algorithm - term1AsSLE := term1Converted.ToScalarLinearExpression() + term1AsSLE := right.ToScalarLinearExpression() return c.Multiply(term1AsSLE) case ScalarLinearExpr: // Scale all vectors and constants - sleOut := term1Converted.Copy() + sleOut := right.Copy() sleOut.L.ScaleVec(float64(c), &sleOut.L) - sleOut.C = term1Converted.C * float64(c) + sleOut.C = right.C * float64(c) return sleOut, nil case ScalarQuadraticExpression: // Scale all matrices and constants var sqeOut ScalarQuadraticExpression - sqeOut.Q.Scale(float64(c), &term1Converted.Q) - sqeOut.L.ScaleVec(float64(c), &term1Converted.L) - sqeOut.C = float64(c) * term1Converted.C + sqeOut.Q.Scale(float64(c), &right.Q) + sqeOut.L.ScaleVec(float64(c), &right.L) + sqeOut.C = float64(c) * right.C return sqeOut, nil case KVector: - var prod mat.VecDense = ZerosVector(term1Converted.Len()) - term1AsVecDense := mat.VecDense(term1Converted) + var prod mat.VecDense = ZerosVector(right.Len()) + term1AsVecDense := mat.VecDense(right) prod.ScaleVec(float64(c), &term1AsVecDense) return KVector(prod), nil case KVectorTranspose: - var prod mat.VecDense = ZerosVector(term1Converted.Len()) - term1AsVecDense := mat.VecDense(term1Converted) + var prod mat.VecDense = ZerosVector(right.Len()) + term1AsVecDense := mat.VecDense(right) prod.ScaleVec(float64(c), &term1AsVecDense) return KVectorTranspose(prod), nil case VarVector: - var vleOut VectorLinearExpr - vleOut.X = term1Converted.Copy() - tempIdentity := Identity(term1Converted.Len()) // Is this needed? - vleOut.L.Scale(float64(c), &tempIdentity) - vleOut.C = ZerosVector(term1Converted.Len()) - - return vleOut, nil + // VarVector is of unit length. + return ScalarLinearExpr{ + L: OnesVector(1), + X: right.Copy(), + C: 0.0, + }, nil case VarVectorTranspose: - var vleOut VectorLinearExpressionTranspose - vleOut.X = term1Converted.Copy().Transpose().(VarVector) - tempIdentity := Identity(term1Converted.Len()) // Is this needed? - vleOut.L.Scale(float64(c), &tempIdentity) - vleOut.C = ZerosVector(term1Converted.Len()) - - return vleOut, nil + if right.Len() == 1 { + rightTransposed := right.Transpose().(VarVector) + return ScalarLinearExpr{ + L: OnesVector(1), + X: rightTransposed.Copy(), + C: 0.0, + }, nil + } else { + var vleOut VectorLinearExpressionTranspose + vleOut.X = right.Copy().Transpose().(VarVector) + tempIdentity := Identity(right.Len()) // Is this needed? + vleOut.L.Scale(float64(c), &tempIdentity) + vleOut.C = ZerosVector(right.Len()) + + return vleOut, nil + } case VectorLinearExpr: var vleOut VectorLinearExpr - vleOut.L.Scale(float64(c), &term1Converted.L) - vleOut.C.ScaleVec(float64(c), &term1Converted.C) - vleOut.X = term1Converted.X.Copy() + vleOut.L.Scale(float64(c), &right.L) + vleOut.C.ScaleVec(float64(c), &right.C) + vleOut.X = right.X.Copy() return vleOut, nil case VectorLinearExpressionTranspose: var vletOut VectorLinearExpressionTranspose - vletOut.L.Scale(float64(c), &term1Converted.L) - vletOut.C.ScaleVec(float64(c), &term1Converted.C) - vletOut.X = term1Converted.X.Copy() + vletOut.L.Scale(float64(c), &right.L) + vletOut.C.ScaleVec(float64(c), &right.C) + vletOut.X = right.X.Copy() return vletOut, nil default: @@ -222,3 +230,7 @@ func (c K) Multiply(term1 interface{}, errors ...error) (Expression, error) { func (c K) Dims() []int { return []int{1, 1} // Signifies scalar } + +func (c K) Check() error { + return nil +} diff --git a/testing/optim/constant_test.go b/testing/optim/constant_test.go index efe7aa1..95f7107 100644 --- a/testing/optim/constant_test.go +++ b/testing/optim/constant_test.go @@ -911,3 +911,83 @@ func TestK_Multiply9(t *testing.T) { } } } + +/* +TestK_Multiply10 +Description: + + Tests the ability to multiply a constant with a VarVector + of non-unit length. +*/ +func TestK_Multiply10(t *testing.T) { + // Constants + m := optim.NewModel("TestK_Multiply10") + N := 3 + c1 := optim.K(3.14) + + kv2 := m.AddVariableVector(N) + + // Algorithm + _, err := c1.Multiply(kv2) + if err == nil { + t.Errorf("no error was thrown, but there should have been!") + } else { + if !strings.Contains( + err.Error(), + optim.DimensionError{ + Operation: "Multiply", + Arg1: c1, + Arg2: kv2, + }.Error(), + ) { + t.Errorf("unexpected error: %v", err) + } + } +} + +/* +TestK_Multiply11 +Description: + + Tests the ability to multiply a constant with a VarVector + of unit length. +*/ +func TestK_Multiply11(t *testing.T) { + // Constants + m := optim.NewModel("TestK_Multiply10") + N := 1 + c1 := optim.K(3.14) + + kv2 := m.AddVariableVector(N) + + // Algorithm + prod, err := c1.Multiply(kv2) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + prodAsSLE, tf := prod.(optim.ScalarLinearExpr) + if !tf { + t.Errorf("expected product to be of type ScalarLinearExpr; received %T", prod) + } + + if prodAsSLE.C != 0.0 { + t.Errorf("prod.C = %v =/= 0.0", prodAsSLE.C) + } +} + +/* +TestK_Check1 +Description: + + Tests that the Check() method returns nil as expected. +*/ +func TestK_Check1(t *testing.T) { + // Constants + k1 := optim.K(3.14) + + // Algorithm + if k1.Check() != nil { + t.Errorf("unexpected error: %v", k1.Check()) + } +}