diff --git a/pyomo/gdp/plugins/multiple_bigm.py b/pyomo/gdp/plugins/multiple_bigm.py index 4dffd4e9f9a..3362276246b 100644 --- a/pyomo/gdp/plugins/multiple_bigm.py +++ b/pyomo/gdp/plugins/multiple_bigm.py @@ -12,7 +12,7 @@ import itertools import logging -from pyomo.common.collections import ComponentMap +from pyomo.common.collections import ComponentMap, ComponentSet from pyomo.common.config import ConfigDict, ConfigValue from pyomo.common.gc_manager import PauseGC from pyomo.common.modeling import unique_component_name @@ -310,9 +310,12 @@ def _transform_disjunctionData(self, obj, index, parent_disjunct, root_disjunct) arg_Ms = self._config.bigM if self._config.bigM is not None else {} + # ESJ: I am relying on the fact that the ComponentSet is going to be + # ordered here, but using a set because I will remove infeasible + # Disjuncts from it if I encounter them calculating M's. + active_disjuncts = ComponentSet(disj for disj in obj.disjuncts if disj.active) # First handle the bound constraints if we are dealing with them # separately - active_disjuncts = [disj for disj in obj.disjuncts if disj.active] transformed_constraints = set() if self._config.reduce_bound_constraints: transformed_constraints = self._transform_bound_constraints( @@ -585,7 +588,7 @@ def _calculate_missing_M_values( ): if disjunct is other_disjunct: continue - if id(other_disjunct) in scratch_blocks: + elif id(other_disjunct) in scratch_blocks: scratch = scratch_blocks[id(other_disjunct)] else: scratch = scratch_blocks[id(other_disjunct)] = Block() @@ -631,7 +634,10 @@ def _calculate_missing_M_values( scratch.obj.expr = constraint.body - constraint.lower scratch.obj.sense = minimize lower_M = self._solve_disjunct_for_M( - other_disjunct, scratch, unsuccessful_solve_msg + other_disjunct, + scratch, + unsuccessful_solve_msg, + active_disjuncts, ) if constraint.upper is not None and upper_M is None: # last resort: calculate @@ -639,7 +645,10 @@ def _calculate_missing_M_values( scratch.obj.expr = constraint.body - constraint.upper scratch.obj.sense = maximize upper_M = self._solve_disjunct_for_M( - other_disjunct, scratch, unsuccessful_solve_msg + other_disjunct, + scratch, + unsuccessful_solve_msg, + active_disjuncts, ) arg_Ms[constraint, other_disjunct] = (lower_M, upper_M) transBlock._mbm_values[constraint, other_disjunct] = (lower_M, upper_M) @@ -651,9 +660,18 @@ def _calculate_missing_M_values( return arg_Ms def _solve_disjunct_for_M( - self, other_disjunct, scratch_block, unsuccessful_solve_msg + self, other_disjunct, scratch_block, unsuccessful_solve_msg, active_disjuncts ): + if not other_disjunct.active: + # If a Disjunct is infeasible, we will discover that and deactivate + # it when we are calculating the M values. We remove that disjunct + # from active_disjuncts inside of the loop in + # _calculate_missing_M_values. So that means that we might have + # deactivated Disjuncts here that we should skip over. + return 0 + solver = self._config.solver + results = solver.solve(other_disjunct, load_solutions=False) if results.solver.termination_condition is TerminationCondition.infeasible: # [2/18/24]: TODO: After the solver rewrite is complete, we will not @@ -669,6 +687,7 @@ def _solve_disjunct_for_M( "Disjunct '%s' is infeasible, deactivating." % other_disjunct.name ) other_disjunct.deactivate() + active_disjuncts.remove(other_disjunct) M = 0 else: # This is a solver that might report diff --git a/pyomo/gdp/tests/test_mbigm.py b/pyomo/gdp/tests/test_mbigm.py index 9e82b1010f9..14a23160574 100644 --- a/pyomo/gdp/tests/test_mbigm.py +++ b/pyomo/gdp/tests/test_mbigm.py @@ -1019,39 +1019,34 @@ def test_calculate_Ms_infeasible_Disjunct(self): out.getvalue().strip(), ) - # We just fixed the infeasible by to False + # We just fixed the infeasible disjunct to False self.assertFalse(m.disjunction.disjuncts[0].active) self.assertTrue(m.disjunction.disjuncts[0].indicator_var.fixed) self.assertFalse(value(m.disjunction.disjuncts[0].indicator_var)) + # We didn't actually transform the infeasible disjunct + self.assertIsNone(m.disjunction.disjuncts[0].transformation_block) + # the remaining constraints are transformed correctly. cons = mbm.get_transformed_constraints(m.disjunction.disjuncts[1].constraint[1]) self.assertEqual(len(cons), 1) assertExpressionsEqual( self, cons[0].expr, - 21 + m.x - m.y - <= 0 * m.disjunction.disjuncts[0].binary_indicator_var - + 12.0 * m.disjunction.disjuncts[2].binary_indicator_var, + 21 + m.x - m.y <= 12.0 * m.disjunction.disjuncts[2].binary_indicator_var, ) cons = mbm.get_transformed_constraints(m.disjunction.disjuncts[2].constraint[1]) self.assertEqual(len(cons), 2) - print(cons[0].expr) - print(cons[1].expr) assertExpressionsEqual( self, cons[0].expr, - 0.0 * m.disjunction_disjuncts[0].binary_indicator_var - - 12.0 * m.disjunction_disjuncts[1].binary_indicator_var - <= m.x - (m.y - 9), + -12.0 * m.disjunction_disjuncts[1].binary_indicator_var <= m.x - (m.y - 9), ) assertExpressionsEqual( self, cons[1].expr, - m.x - (m.y - 9) - <= 0.0 * m.disjunction_disjuncts[0].binary_indicator_var - - 12.0 * m.disjunction_disjuncts[1].binary_indicator_var, + m.x - (m.y - 9) <= -12.0 * m.disjunction_disjuncts[1].binary_indicator_var, ) @unittest.skipUnless(