Skip to content

Commit

Permalink
Merge pull request #59 from firedrakeproject/py3-compat
Browse files Browse the repository at this point in the history
Make GEM and TSFC mostly Python 3 compatible
  • Loading branch information
miklos1 authored Sep 26, 2016
2 parents 5061555 + 8535862 commit c69e5d2
Show file tree
Hide file tree
Showing 32 changed files with 110 additions and 64 deletions.
12 changes: 11 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
notifications:
slack:
secure: ZaeRnZLbnqMqFoH10X6U9ylAM4f9Up3ebgA097FVy71VD8hjMiw3lYZb7p/WxxsaQAplOykLQ8PUeuV7rOhotzT23xYijh3TpDrmxs+TnZ6HhHnILvV85YNfLyrvniPb6kHQHVcMTDPvIVqR9MP6gZqxZlnpqUttzYgr2rCv0vl98XE/TeY+IRUldsGEdZ84WTrWsNCb/LqCyM3+FJCHaXRZcmokHJsdStZNWpC4S9MfiDzd9DzH4pg4PT4xKLIRCtguhxnc/or56LYCnTV4PRRk4rdVaHOHRoN/0EGKOG9eWTInhDY33me7cCEAtqGsngBzpehQXDfd7YI6lnWriGpYRk+BmZaRrcuA40QEQa5z0WoHA/m7ev+3S6PGniI+ZZg2OuAw2gUgRIHfqwQ/oVREvVzR0HM9yia5JkyH9y6HtbCB9GBpt8mB+kIWQ8BIhk4sdHnKbUuEumjr9msnLQwqpI4jLoN0Ap+ZNnwoK8Lzjf/qGlcSsmkCQ2H3aI0s+q7mqHg9ZqDn0cph8qAkdCpnedHYk/itK/tPssN0O9jgOgqseVV8p2BWxR39LnHEE9F1ghTtHuW1wmBfokEa9JD6OraG/jG3Og9lqulnsNNk3bHf4digrVaHKQkprI0jeAfO0QcyuMGuAsvuY++CwzVC7SjRf02S555E9jFqnek=

language: python
python:
- "2.7"
- "3.5"

env:
- NUMPY_VERSION="==1.9.0"
- NUMPY_VERSION=">=1.10"

matrix:
exclude:
- python: "3.5"
include:
- python: "3.5"
env: NUMPY_VERSION=">=1.10"

before_install:
- pip install -r requirements.txt
- pip install flake8
- pip install flake8-future-import
- pip install --upgrade pytest
- pip install --upgrade "numpy$NUMPY_VERSION"

install:
- python setup.py install

script:
- py.test tests
- flake8 .
- py.test tests
2 changes: 1 addition & 1 deletion gem/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import absolute_import
from __future__ import absolute_import, print_function, division

from gem.gem import * # noqa
42 changes: 27 additions & 15 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
indices.
"""

from __future__ import absolute_import
from __future__ import absolute_import, print_function, division
from six import with_metaclass

from abc import ABCMeta
from itertools import chain
import numpy
from numpy import asarray, unique
from operator import attrgetter

import numpy
from numpy import asarray

from gem.node import Node as NodeBase


Expand All @@ -46,17 +48,15 @@ def __call__(self, *args, **kwargs):

# Set free_indices if not set already
if not hasattr(obj, 'free_indices'):
cfi = list(chain(*[c.free_indices for c in obj.children]))
obj.free_indices = tuple(unique(cfi))
obj.free_indices = unique(chain(*[c.free_indices
for c in obj.children]))

return obj


class Node(NodeBase):
class Node(with_metaclass(NodeMeta, NodeBase)):
"""Abstract GEM node class."""

__metaclass__ = NodeMeta

__slots__ = ('free_indices')

def is_equal(self, other):
Expand Down Expand Up @@ -339,10 +339,9 @@ def __init__(self, condition, then, else_):
self.shape = then.shape


class IndexBase(object):
class IndexBase(with_metaclass(ABCMeta)):
"""Abstract base class for indices."""

__metaclass__ = ABCMeta
pass

IndexBase.register(int)

Expand Down Expand Up @@ -379,6 +378,10 @@ def __repr__(self):
return "Index(%r)" % self.count
return "Index(%r)" % self.name

def __lt__(self, other):
# Allow sorting of free indices in Python 3
return id(self) < id(other)


class VariableIndex(IndexBase):
"""An index that is constant during a single execution of the
Expand Down Expand Up @@ -432,7 +435,7 @@ def __new__(cls, aggregate, multiindex):
self.multiindex = multiindex

new_indices = tuple(i for i in multiindex if isinstance(i, Index))
self.free_indices = tuple(unique(aggregate.free_indices + new_indices))
self.free_indices = unique(aggregate.free_indices + new_indices)

return self

Expand Down Expand Up @@ -486,7 +489,7 @@ def __init__(self, variable, dim2idxs):

self.children = (variable,)
self.dim2idxs = dim2idxs
self.free_indices = tuple(unique(indices))
self.free_indices = unique(indices)


class ComponentTensor(Node):
Expand Down Expand Up @@ -515,7 +518,7 @@ def __new__(cls, expression, multiindex):

# Collect free indices
assert set(multiindex) <= set(expression.free_indices)
self.free_indices = tuple(unique(list(set(expression.free_indices) - set(multiindex))))
self.free_indices = unique(set(expression.free_indices) - set(multiindex))

return self

Expand All @@ -540,7 +543,7 @@ def __new__(cls, summand, index):

# Collect shape and free indices
assert index in summand.free_indices
self.free_indices = tuple(unique(list(set(summand.free_indices) - {index})))
self.free_indices = unique(set(summand.free_indices) - {index})

return self

Expand Down Expand Up @@ -599,6 +602,15 @@ def get_hash(self):
return hash((type(self), self.shape, self.children))


def unique(indices):
"""Sorts free indices and eliminates duplicates.
:arg indices: iterable of indices
:returns: sorted tuple of unique free indices
"""
return tuple(sorted(set(indices), key=id))


def partial_indexed(tensor, indices):
"""Generalised indexing into a tensor. The number of indices may
be less than or equal to the rank of the tensor, so the result may
Expand Down
7 changes: 3 additions & 4 deletions gem/impero.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
(Command?) after clicking on them.
"""

from __future__ import absolute_import
from __future__ import absolute_import, print_function, division
from six import with_metaclass

from abc import ABCMeta, abstractmethod

Expand All @@ -23,11 +24,9 @@ class Node(NodeBase):
__slots__ = ()


class Terminal(Node):
class Terminal(with_metaclass(ABCMeta, Node)):
"""Abstract class for terminal Impero nodes"""

__metaclass__ = ABCMeta

__slots__ = ()

children = ()
Expand Down
14 changes: 9 additions & 5 deletions gem/impero_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
C code or a COFFEE AST.
"""

from __future__ import absolute_import
from __future__ import absolute_import, print_function, division
from six.moves import zip

import collections
import itertools
Expand Down Expand Up @@ -57,11 +58,14 @@ def compile_gem(return_variables, expressions, prefix_ordering, remove_zeros=Fal
indices = []
for node in traversal(expressions):
if isinstance(node, gem.Indexed):
indices.extend(node.multiindex)
for index in node.multiindex:
if isinstance(index, gem.Index):
indices.append(index)
elif isinstance(node, gem.FlexiblyIndexed):
for offset, idxs in node.dim2idxs:
for index, stride in idxs:
indices.append(index)
if isinstance(index, gem.Index):
indices.append(index)
# The next two lines remove duplicate elements from the list, but
# preserve the ordering, i.e. all elements will appear only once,
# in the order of their first occurance in the original list.
Expand All @@ -75,7 +79,7 @@ def compile_gem(return_variables, expressions, prefix_ordering, remove_zeros=Fal
get_indices = lambda expr: apply_ordering(expr.free_indices)

# Build operation ordering
ops = scheduling.emit_operations(zip(return_variables, expressions), get_indices)
ops = scheduling.emit_operations(list(zip(return_variables, expressions)), get_indices)

# Empty kernel
if len(ops) == 0:
Expand Down Expand Up @@ -177,7 +181,7 @@ def make_loop_tree(ops, get_indices, level=0):
else:
statements.extend(op_group)
# Remove no-op terminals from the tree
statements = filter(lambda s: not isinstance(s, imp.Noop), statements)
statements = [s for s in statements if not isinstance(s, imp.Noop)]
return imp.Block(statements)


Expand Down
5 changes: 3 additions & 2 deletions gem/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""
An interpreter for GEM trees.
"""
from __future__ import absolute_import
from __future__ import absolute_import, print_function, division
from six.moves import map

import numpy
import operator
Expand Down Expand Up @@ -302,4 +303,4 @@ def evaluate(expressions, bindings=None):
exprs = (expressions, )
mapper = node.Memoizer(_evaluate)
mapper.bindings = bindings if bindings is not None else {}
return map(mapper, exprs)
return list(map(mapper, exprs))
5 changes: 3 additions & 2 deletions gem/node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Generic abstract node class and utility functions for creating
expression DAG languages."""

from __future__ import absolute_import
from __future__ import absolute_import, print_function, division
from six.moves import map

import collections

Expand Down Expand Up @@ -201,7 +202,7 @@ def __call__(self, node, arg):

def reuse_if_untouched(node, self):
"""Reuse if untouched recipe"""
new_children = map(self, node.children)
new_children = list(map(self, node.children))
if all(nc == c for nc, c in zip(new_children, node.children)):
return node
else:
Expand Down
6 changes: 4 additions & 2 deletions gem/optimise.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""A set of routines implementing various transformations on GEM
expressions."""

from __future__ import absolute_import
from __future__ import absolute_import, print_function, division
from six.moves import map

from functools import reduce
from singledispatch import singledispatch

from gem.node import Memoizer, MemoizerArg, reuse_if_untouched, reuse_if_untouched_arg
Expand Down Expand Up @@ -111,4 +113,4 @@ def unroll_indexsum(expressions, max_extent):
"""
mapper = Memoizer(_unroll_indexsum)
mapper.max_extent = max_extent
return map(mapper, expressions)
return list(map(mapper, expressions))
2 changes: 1 addition & 1 deletion gem/scheduling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Schedules operations to evaluate a multi-root expression DAG,
forming an ordered list of Impero terminals."""

from __future__ import absolute_import
from __future__ import absolute_import, print_function, division

import collections
import functools
Expand Down
1 change: 1 addition & 0 deletions requirements-ext.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
numpy
singledispatch
six
8 changes: 6 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[flake8]
ignore = E501,E226,E731
exclude = .git,__pycache__
ignore =
E501,E226,E731,
FI14,FI54,
FI50,FI51,FI53
exclude = .git,__pycache__,build,dist
min-version = 2.7
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import print_function, division, absolute_import

from distutils.core import setup

version = "0.0.1"
Expand Down
2 changes: 2 additions & 0 deletions tests/test_codegen.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import absolute_import, print_function, division

import pytest

from gem import impero_utils
Expand Down
1 change: 1 addition & 0 deletions tests/test_create_element.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import absolute_import, print_function, division
from tsfc import fiatinterface as f
import pytest
import ufl
Expand Down
1 change: 1 addition & 0 deletions tests/test_idempotency.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import absolute_import, print_function, division
import ufl
from tsfc import compile_form
import pytest
Expand Down
2 changes: 1 addition & 1 deletion tsfc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import absolute_import
from __future__ import absolute_import, print_function, division

from tsfc.driver import compile_form # noqa
3 changes: 2 additions & 1 deletion tsfc/coffee.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
This is the final stage of code generation in TSFC."""

from __future__ import absolute_import
from __future__ import absolute_import, print_function, division

from collections import defaultdict
from functools import reduce
from math import isnan
import itertools

Expand Down
1 change: 1 addition & 0 deletions tsfc/compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Backwards compatibility for some functionality.
"""
from __future__ import absolute_import, print_function, division
import numpy
from distutils.version import StrictVersion

Expand Down
2 changes: 1 addition & 1 deletion tsfc/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import absolute_import
from __future__ import absolute_import, print_function, division

import numpy

Expand Down
6 changes: 4 additions & 2 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import absolute_import
from __future__ import absolute_import, print_function, division
from six.moves import range

import collections
import time
from functools import reduce

from ufl.classes import Form
from ufl.log import GREEN
Expand Down Expand Up @@ -277,6 +279,6 @@ def lower_integral_type(fiat_cell, integral_type):
elif integral_type == 'exterior_facet_top':
entity_ids = [1]
else:
entity_ids = range(len(fiat_cell.get_topology()[integration_dim]))
entity_ids = list(range(len(fiat_cell.get_topology()[integration_dim])))

return integration_dim, entity_ids
Loading

0 comments on commit c69e5d2

Please sign in to comment.