Skip to content

Commit

Permalink
Merge pull request #92 from firedrakeproject/impero-cleanup
Browse files Browse the repository at this point in the history
Minor simplications in impero_utils.py and remove TODO.md
  • Loading branch information
miklos1 authored Dec 13, 2016
2 parents ec0432a + 166b608 commit 3aadebe
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 36 deletions.
12 changes: 0 additions & 12 deletions TODO.md

This file was deleted.

47 changes: 24 additions & 23 deletions gem/impero_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def compile_gem(return_variables, expressions, prefix_ordering, remove_zeros=Fal
tree = make_loop_tree(ops, get_indices)

# Collect temporaries
temporaries = collect_temporaries(ops)
temporaries = collect_temporaries(tree)

# Determine declarations
declare, indices = place_declarations(ops, tree, temporaries, get_indices)
declare, indices = place_declarations(tree, temporaries, get_indices)

# Prepare ImperoC (Impero AST + other data for code generation)
return ImperoC(tree, temporaries, declare, indices)
Expand Down Expand Up @@ -147,18 +147,18 @@ def inline_temporaries(expressions, ops):
return [op for op in ops if not (isinstance(op, imp.Evaluate) and op.expression in candidates)]


def collect_temporaries(ops):
def collect_temporaries(tree):
"""Collects GEM expressions to assign to temporaries from a list
of Impero terminals."""
result = []
for op in ops:
for node in traversal((tree,)):
# IndexSum temporaries should be added either at Initialise or
# at Accumulate. The difference is only in ordering
# (numbering). We chose Accumulate here.
if isinstance(op, imp.Accumulate):
result.append(op.indexsum)
elif isinstance(op, imp.Evaluate):
result.append(op.expression)
if isinstance(node, imp.Accumulate):
result.append(node.indexsum)
elif isinstance(node, imp.Evaluate):
result.append(node.expression)
return result


Expand All @@ -185,10 +185,9 @@ def make_loop_tree(ops, get_indices, level=0):
return imp.Block(statements)


def place_declarations(ops, tree, temporaries, get_indices):
def place_declarations(tree, temporaries, get_indices):
"""Determines where and how to declare temporaries for an Impero AST.
:arg ops: terminals of ``tree``
:arg tree: Impero AST to determine the declarations for
:arg temporaries: list of GEM expressions which are assigned to
temporaries
Expand All @@ -200,8 +199,9 @@ def place_declarations(ops, tree, temporaries, get_indices):

# Collect the total number of temporary references
total_refcount = collections.Counter()
for op in ops:
total_refcount.update(temp_refcount(temporaries_set, op))
for node in traversal((tree,)):
if isinstance(node, imp.Terminal):
total_refcount.update(temp_refcount(temporaries_set, node))
assert temporaries_set == set(total_refcount)

# Result
Expand Down Expand Up @@ -264,17 +264,18 @@ def recurse_block(expr, loop_indices):

# Set in ``declare`` for Impero terminals whether they should
# declare the temporary that they are writing to.
for op in ops:
declare[op] = False
if isinstance(op, imp.Evaluate):
e = op.expression
elif isinstance(op, imp.Initialise):
e = op.indexsum
else:
continue

if len(indices[e]) == 0:
declare[op] = True
for node in traversal((tree,)):
if isinstance(node, imp.Terminal):
declare[node] = False
if isinstance(node, imp.Evaluate):
e = node.expression
elif isinstance(node, imp.Initialise):
e = node.indexsum
else:
continue

if len(indices[e]) == 0:
declare[node] = True

return declare, indices

Expand Down
2 changes: 1 addition & 1 deletion gem/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def traversal(expression_dags):
while lifo:
node = lifo.pop()
yield node
for child in node.children:
for child in reversed(node.children):
if child not in seen:
seen.add(child)
lifo.append(child)
Expand Down

0 comments on commit 3aadebe

Please sign in to comment.