Skip to content

Commit

Permalink
use new variable visitor in get_vars_from_components rather than iden…
Browse files Browse the repository at this point in the history
…tify_variables_in_expressions
  • Loading branch information
Robbybp committed Mar 11, 2024
1 parent 313a310 commit e9e63a0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 42 deletions.
34 changes: 0 additions & 34 deletions pyomo/core/expr/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1433,40 +1433,6 @@ def acceptChildResult(self, node, data, child_result, child_idx):
return child_result.is_expression_type(), None


def identify_variables_in_components(components, include_fixed=True):
visitor = _StreamVariableVisitor(
include_fixed=include_fixed, descend_into_named_expressions=False
)
all_variables = []
for comp in components:
all_variables.extend(visitor.walk_expressions(comp.expr))

named_expr_set = set()
unique_named_exprs = []
for expr in visitor.named_expressions:
if id(expr) in named_expr_set:
named_expr_set.add(id(expr))
unique_named_exprs.append(expr)

while unique_named_exprs:
expr = unique_named_exprs.pop()
visitor.named_expressions.clear()
all_variables.extend(visitor.walk_expression(expr.expr))

for new_expr in visitor.named_expressions:
if id(new_expr) not in named_expr_set:
named_expr_set.add(new_expr)
unique_named_exprs.append(new_expr)

unique_vars = []
var_set = set()
for var in all_variables:
if id(var) not in var_set:
var_set.add(id(var))
unique_vars.append(var)
return unique_vars


def identify_variables(expr, include_fixed=True):
"""
A generator that yields a sequence of variables
Expand Down
31 changes: 23 additions & 8 deletions pyomo/util/vars_from_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
actually in the subtree or not.
"""
from pyomo.core import Block
import pyomo.core.expr as EXPR
from pyomo.core.expr.visitor import _StreamVariableVisitor


def get_vars_from_components(
Expand All @@ -42,17 +42,32 @@ def get_vars_from_components(
descend_into: Ctypes to descend into when finding Constraints
descent_order: Traversal strategy for finding the objects of type ctype
"""
seen = set()
visitor = _StreamVariableVisitor(
include_fixed=include_fixed, descend_into_named_expressions=False
)
variables = []
for constraint in block.component_data_objects(
ctype,
active=active,
sort=sort,
descend_into=descend_into,
descent_order=descent_order,
):
for var in EXPR.identify_variables(
constraint.expr, include_fixed=include_fixed
):
if id(var) not in seen:
seen.add(id(var))
yield var
variables.extend(visitor.walk_expression(constraint.expr))
seen_named_exprs = set()
named_expr_stack = list(visitor.named_expressions)
while named_expr_stack:
expr = named_expr_stack.pop()
# Clear visitor's named expression cache so we only identify new
# named expressions
visitor.named_expressions.clear()
variables.extend(visitor.walk_expression(expr.expr))
for new_expr in visitor.named_expressions:
if id(new_expr) not in seen_named_exprs:
seen_named_exprs.add(id(new_expr))
named_expr_stack.append(new_expr)
seen = set()
for var in variables:
if id(var) not in seen:
seen.add(id(var))
yield var

0 comments on commit e9e63a0

Please sign in to comment.