Skip to content

Commit

Permalink
Merge branch 'main' into indsetwithin
Browse files Browse the repository at this point in the history
  • Loading branch information
jsiirola authored Jul 9, 2024
2 parents ef55217 + b50cfac commit df73348
Show file tree
Hide file tree
Showing 7 changed files with 1,086 additions and 148 deletions.
19 changes: 13 additions & 6 deletions pyomo/core/expr/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,14 @@ def assertExpressionsEqual(test, a, b, include_named_exprs=True, places=None):
test.assertEqual(len(prefix_a), len(prefix_b))
for _a, _b in zip(prefix_a, prefix_b):
test.assertIs(_a.__class__, _b.__class__)
if places is None:
test.assertEqual(_a, _b)
# If _a is nan, check _b is nan
if _a != _a:
test.assertTrue(_b != _b)
else:
test.assertAlmostEqual(_a, _b, places=places)
if places is None:
test.assertEqual(_a, _b)
else:
test.assertAlmostEqual(_a, _b, places=places)
except (PyomoException, AssertionError):
test.fail(
f"Expressions not equal:\n\t"
Expand Down Expand Up @@ -292,10 +296,13 @@ def assertExpressionsStructurallyEqual(
for _a, _b in zip(prefix_a, prefix_b):
if _a.__class__ not in native_types and _b.__class__ not in native_types:
test.assertIs(_a.__class__, _b.__class__)
if places is None:
test.assertEqual(_a, _b)
if _a != _a:
test.assertTrue(_b != _b)
else:
test.assertAlmostEqual(_a, _b, places=places)
if places is None:
test.assertEqual(_a, _b)
else:
test.assertAlmostEqual(_a, _b, places=places)
except (PyomoException, AssertionError):
test.fail(
f"Expressions not structurally equal:\n\t"
Expand Down
122 changes: 57 additions & 65 deletions pyomo/repn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ def to_expression(visitor, arg):
return arg[1].to_expression(visitor)


_exit_node_handlers = {}

#
# NEGATION handlers
#
Expand All @@ -199,11 +197,6 @@ def _handle_negation_ANY(visitor, node, arg):
return arg


_exit_node_handlers[NegationExpression] = {
None: _handle_negation_ANY,
(_CONSTANT,): _handle_negation_constant,
}

#
# PRODUCT handlers
#
Expand Down Expand Up @@ -272,16 +265,6 @@ def _handle_product_nonlinear(visitor, node, arg1, arg2):
return _GENERAL, ans


_exit_node_handlers[ProductExpression] = {
None: _handle_product_nonlinear,
(_CONSTANT, _CONSTANT): _handle_product_constant_constant,
(_CONSTANT, _LINEAR): _handle_product_constant_ANY,
(_CONSTANT, _GENERAL): _handle_product_constant_ANY,
(_LINEAR, _CONSTANT): _handle_product_ANY_constant,
(_GENERAL, _CONSTANT): _handle_product_ANY_constant,
}
_exit_node_handlers[MonomialTermExpression] = _exit_node_handlers[ProductExpression]

#
# DIVISION handlers
#
Expand All @@ -302,13 +285,6 @@ def _handle_division_nonlinear(visitor, node, arg1, arg2):
return _GENERAL, ans


_exit_node_handlers[DivisionExpression] = {
None: _handle_division_nonlinear,
(_CONSTANT, _CONSTANT): _handle_division_constant_constant,
(_LINEAR, _CONSTANT): _handle_division_ANY_constant,
(_GENERAL, _CONSTANT): _handle_division_ANY_constant,
}

#
# EXPONENTIATION handlers
#
Expand Down Expand Up @@ -345,13 +321,6 @@ def _handle_pow_nonlinear(visitor, node, arg1, arg2):
return _GENERAL, ans


_exit_node_handlers[PowExpression] = {
None: _handle_pow_nonlinear,
(_CONSTANT, _CONSTANT): _handle_pow_constant_constant,
(_LINEAR, _CONSTANT): _handle_pow_ANY_constant,
(_GENERAL, _CONSTANT): _handle_pow_ANY_constant,
}

#
# ABS and UNARY handlers
#
Expand All @@ -371,12 +340,6 @@ def _handle_unary_nonlinear(visitor, node, arg):
return _GENERAL, ans


_exit_node_handlers[UnaryFunctionExpression] = {
None: _handle_unary_nonlinear,
(_CONSTANT,): _handle_unary_constant,
}
_exit_node_handlers[AbsExpression] = _exit_node_handlers[UnaryFunctionExpression]

#
# NAMED EXPRESSION handlers
#
Expand All @@ -395,11 +358,6 @@ def _handle_named_ANY(visitor, node, arg1):
return _type, arg1.duplicate()


_exit_node_handlers[Expression] = {
None: _handle_named_ANY,
(_CONSTANT,): _handle_named_constant,
}

#
# EXPR_IF handlers
#
Expand Down Expand Up @@ -430,11 +388,6 @@ def _handle_expr_if_nonlinear(visitor, node, arg1, arg2, arg3):
return _GENERAL, ans


_exit_node_handlers[Expr_ifExpression] = {None: _handle_expr_if_nonlinear}
for j in (_CONSTANT, _LINEAR, _GENERAL):
for k in (_CONSTANT, _LINEAR, _GENERAL):
_exit_node_handlers[Expr_ifExpression][_CONSTANT, j, k] = _handle_expr_if_const

#
# Relational expression handlers
#
Expand Down Expand Up @@ -462,12 +415,6 @@ def _handle_equality_general(visitor, node, arg1, arg2):
return _GENERAL, ans


_exit_node_handlers[EqualityExpression] = {
None: _handle_equality_general,
(_CONSTANT, _CONSTANT): _handle_equality_const,
}


def _handle_inequality_const(visitor, node, arg1, arg2):
# It is exceptionally likely that if we get here, one of the
# arguments is an InvalidNumber
Expand All @@ -490,12 +437,6 @@ def _handle_inequality_general(visitor, node, arg1, arg2):
return _GENERAL, ans


_exit_node_handlers[InequalityExpression] = {
None: _handle_inequality_general,
(_CONSTANT, _CONSTANT): _handle_inequality_const,
}


def _handle_ranged_const(visitor, node, arg1, arg2, arg3):
# It is exceptionally likely that if we get here, one of the
# arguments is an InvalidNumber
Expand Down Expand Up @@ -523,10 +464,62 @@ def _handle_ranged_general(visitor, node, arg1, arg2, arg3):
return _GENERAL, ans


_exit_node_handlers[RangedExpression] = {
None: _handle_ranged_general,
(_CONSTANT, _CONSTANT, _CONSTANT): _handle_ranged_const,
}
def define_exit_node_handlers(_exit_node_handlers=None):
if _exit_node_handlers is None:
_exit_node_handlers = {}
_exit_node_handlers[NegationExpression] = {
None: _handle_negation_ANY,
(_CONSTANT,): _handle_negation_constant,
}
_exit_node_handlers[ProductExpression] = {
None: _handle_product_nonlinear,
(_CONSTANT, _CONSTANT): _handle_product_constant_constant,
(_CONSTANT, _LINEAR): _handle_product_constant_ANY,
(_CONSTANT, _GENERAL): _handle_product_constant_ANY,
(_LINEAR, _CONSTANT): _handle_product_ANY_constant,
(_GENERAL, _CONSTANT): _handle_product_ANY_constant,
}
_exit_node_handlers[MonomialTermExpression] = _exit_node_handlers[ProductExpression]
_exit_node_handlers[DivisionExpression] = {
None: _handle_division_nonlinear,
(_CONSTANT, _CONSTANT): _handle_division_constant_constant,
(_LINEAR, _CONSTANT): _handle_division_ANY_constant,
(_GENERAL, _CONSTANT): _handle_division_ANY_constant,
}
_exit_node_handlers[PowExpression] = {
None: _handle_pow_nonlinear,
(_CONSTANT, _CONSTANT): _handle_pow_constant_constant,
(_LINEAR, _CONSTANT): _handle_pow_ANY_constant,
(_GENERAL, _CONSTANT): _handle_pow_ANY_constant,
}
_exit_node_handlers[UnaryFunctionExpression] = {
None: _handle_unary_nonlinear,
(_CONSTANT,): _handle_unary_constant,
}
_exit_node_handlers[AbsExpression] = _exit_node_handlers[UnaryFunctionExpression]
_exit_node_handlers[Expression] = {
None: _handle_named_ANY,
(_CONSTANT,): _handle_named_constant,
}
_exit_node_handlers[Expr_ifExpression] = {None: _handle_expr_if_nonlinear}
for j in (_CONSTANT, _LINEAR, _GENERAL):
for k in (_CONSTANT, _LINEAR, _GENERAL):
_exit_node_handlers[Expr_ifExpression][
_CONSTANT, j, k
] = _handle_expr_if_const
_exit_node_handlers[EqualityExpression] = {
None: _handle_equality_general,
(_CONSTANT, _CONSTANT): _handle_equality_const,
}
_exit_node_handlers[InequalityExpression] = {
None: _handle_inequality_general,
(_CONSTANT, _CONSTANT): _handle_inequality_const,
}
_exit_node_handlers[RangedExpression] = {
None: _handle_ranged_general,
(_CONSTANT, _CONSTANT, _CONSTANT): _handle_ranged_const,
}
return _exit_node_handlers


class LinearBeforeChildDispatcher(BeforeChildDispatcher):
Expand Down Expand Up @@ -728,9 +721,8 @@ def _initialize_exit_node_dispatcher(exit_handlers):

class LinearRepnVisitor(StreamBasedExpressionVisitor):
Result = LinearRepn
exit_node_handlers = _exit_node_handlers
exit_node_dispatcher = ExitNodeDispatcher(
_initialize_exit_node_dispatcher(_exit_node_handlers)
_initialize_exit_node_dispatcher(define_exit_node_handlers())
)
expand_nonlinear_products = False
max_exponential_expansion = 1
Expand Down
Loading

0 comments on commit df73348

Please sign in to comment.