Skip to content
This repository has been archived by the owner on Dec 6, 2024. It is now read-only.

Commit

Permalink
Merge pull request #93 from firedrakeproject/element-api
Browse files Browse the repository at this point in the history
* element-api:
  remove variant=('equispaced', 'spectral') tests
  test translation of element variants
  translate spectral elements to GLL/GL
  adopt UFL changes
  • Loading branch information
miklos1 committed Jan 30, 2017
2 parents ddab3cb + 6321f8f commit b870d95
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 32 deletions.
66 changes: 57 additions & 9 deletions tests/test_create_fiat_element.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import absolute_import, print_function, division
from tsfc import fiatinterface as f
import pytest

import FIAT
from FIAT.discontinuous_lagrange import HigherOrderDiscontinuousLagrange as FIAT_DiscontinuousLagrange

import ufl
from tsfc.fiatinterface import create_element, supported_elements


@pytest.fixture(params=["BDM",
Expand All @@ -22,8 +26,8 @@ def ufl_element(triangle_names):


def test_triangle_basic(ufl_element):
element = f.create_element(ufl_element)
assert isinstance(element, f.supported_elements[ufl_element.family()])
element = create_element(ufl_element)
assert isinstance(element, supported_elements[ufl_element.family()])


@pytest.fixture(params=["CG", "DG"])
Expand All @@ -46,19 +50,63 @@ def ufl_B(tensor_name):
def test_tensor_prod_simple(ufl_A, ufl_B):
tensor_ufl = ufl.TensorProductElement(ufl_A, ufl_B)

tensor = f.create_element(tensor_ufl)
A = f.create_element(ufl_A)
B = f.create_element(ufl_B)
tensor = create_element(tensor_ufl)
A = create_element(ufl_A)
B = create_element(ufl_B)

assert isinstance(tensor, f.supported_elements[tensor_ufl.family()])
assert isinstance(tensor, supported_elements[tensor_ufl.family()])

assert tensor.A is A
assert tensor.B is B


@pytest.mark.parametrize(('family', 'expected_cls'),
[('P', FIAT.Lagrange),
('DP', FIAT_DiscontinuousLagrange)])
def test_interval_variant_default(family, expected_cls):
ufl_element = ufl.FiniteElement(family, ufl.interval, 3)
assert isinstance(create_element(ufl_element), expected_cls)


@pytest.mark.parametrize(('family', 'variant', 'expected_cls'),
[('P', 'equispaced', FIAT.Lagrange),
('P', 'spectral', FIAT.GaussLobattoLegendre),
('DP', 'equispaced', FIAT_DiscontinuousLagrange),
('DP', 'spectral', FIAT.GaussLegendre)])
def test_interval_variant(family, variant, expected_cls):
ufl_element = ufl.FiniteElement(family, ufl.interval, 3, variant=variant)
assert isinstance(create_element(ufl_element), expected_cls)


def test_triangle_variant_spectral_fail():
ufl_element = ufl.FiniteElement('DP', ufl.triangle, 2, variant='spectral')
with pytest.raises(ValueError):
create_element(ufl_element)


def test_quadrilateral_variant_spectral_q():
element = create_element(ufl.FiniteElement('Q', ufl.quadrilateral, 3, variant='spectral'))
assert isinstance(element.element.A, FIAT.GaussLobattoLegendre)
assert isinstance(element.element.B, FIAT.GaussLobattoLegendre)


def test_quadrilateral_variant_spectral_dq():
element = create_element(ufl.FiniteElement('DQ', ufl.quadrilateral, 1, variant='spectral'))
assert isinstance(element.element.A, FIAT.GaussLegendre)
assert isinstance(element.element.B, FIAT.GaussLegendre)


def test_quadrilateral_variant_spectral_rtcf():
element = create_element(ufl.FiniteElement('RTCF', ufl.quadrilateral, 2, variant='spectral'))
assert isinstance(element.element.A.A, FIAT.GaussLobattoLegendre)
assert isinstance(element.element.A.B, FIAT.GaussLegendre)
assert isinstance(element.element.B.A, FIAT.GaussLegendre)
assert isinstance(element.element.B.B, FIAT.GaussLobattoLegendre)


def test_cache_hit(ufl_element):
A = f.create_element(ufl_element)
B = f.create_element(ufl_element)
A = create_element(ufl_element)
B = create_element(ufl_element)

assert A is B

Expand Down
61 changes: 49 additions & 12 deletions tests/test_create_finat_element.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import absolute_import, print_function, division
from tsfc import finatinterface as f
import pytest

import ufl
import finat
from tsfc.finatinterface import create_element, supported_elements


@pytest.fixture(params=["BDM",
Expand All @@ -23,8 +24,8 @@ def ufl_element(triangle_names):


def test_triangle_basic(ufl_element):
element = f.create_element(ufl_element)
assert isinstance(element, f.supported_elements[ufl_element.family()])
element = create_element(ufl_element)
assert isinstance(element, supported_elements[ufl_element.family()])


@pytest.fixture
Expand All @@ -33,8 +34,8 @@ def ufl_vector_element(triangle_names):


def test_triangle_vector(ufl_element, ufl_vector_element):
scalar = f.create_element(ufl_element)
vector = f.create_element(ufl_vector_element)
scalar = create_element(ufl_element)
vector = create_element(ufl_vector_element)

assert isinstance(vector, finat.TensorFiniteElement)
assert scalar == vector.base_element
Expand All @@ -60,25 +61,61 @@ def ufl_B(tensor_name):
def test_tensor_prod_simple(ufl_A, ufl_B):
tensor_ufl = ufl.TensorProductElement(ufl_A, ufl_B)

tensor = f.create_element(tensor_ufl)
A = f.create_element(ufl_A)
B = f.create_element(ufl_B)
tensor = create_element(tensor_ufl)
A = create_element(ufl_A)
B = create_element(ufl_B)

assert isinstance(tensor, finat.TensorProductElement)

assert tensor.factors == (A, B)


@pytest.mark.parametrize(('family', 'expected_cls'),
[('P', finat.Lagrange),
('DP', finat.DiscontinuousLagrange)])
def test_interval_variant_default(family, expected_cls):
ufl_element = ufl.FiniteElement(family, ufl.interval, 3)
assert isinstance(create_element(ufl_element), expected_cls)


@pytest.mark.parametrize(('family', 'variant', 'expected_cls'),
[('P', 'equispaced', finat.Lagrange),
('P', 'spectral', finat.GaussLobattoLegendre),
('DP', 'equispaced', finat.DiscontinuousLagrange),
('DP', 'spectral', finat.GaussLegendre)])
def test_interval_variant(family, variant, expected_cls):
ufl_element = ufl.FiniteElement(family, ufl.interval, 3, variant=variant)
assert isinstance(create_element(ufl_element), expected_cls)


def test_triangle_variant_spectral_fail():
ufl_element = ufl.FiniteElement('DP', ufl.triangle, 2, variant='spectral')
with pytest.raises(ValueError):
create_element(ufl_element)


def test_quadrilateral_variant_spectral_q():
element = create_element(ufl.FiniteElement('Q', ufl.quadrilateral, 3, variant='spectral'))
assert isinstance(element.product.factors[0], finat.GaussLobattoLegendre)
assert isinstance(element.product.factors[1], finat.GaussLobattoLegendre)


def test_quadrilateral_variant_spectral_dq():
element = create_element(ufl.FiniteElement('DQ', ufl.quadrilateral, 1, variant='spectral'))
assert isinstance(element.product.factors[0], finat.GaussLegendre)
assert isinstance(element.product.factors[1], finat.GaussLegendre)


def test_cache_hit(ufl_element):
A = f.create_element(ufl_element)
B = f.create_element(ufl_element)
A = create_element(ufl_element)
B = create_element(ufl_element)

assert A is B


def test_cache_hit_vector(ufl_vector_element):
A = f.create_element(ufl_vector_element)
B = f.create_element(ufl_vector_element)
A = create_element(ufl_vector_element)
B = create_element(ufl_vector_element)

assert A is B

Expand Down
27 changes: 21 additions & 6 deletions tsfc/fiatinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from FIAT.quadrature import QuadratureRule # noqa

import ufl
from ufl.algorithms.elementtransformations import reconstruct_element

from .mixedelement import MixedElement

Expand Down Expand Up @@ -186,11 +185,27 @@ def _(element, vector_is_mixed):
raise ValueError("%s is supported, but handled incorrectly" %
element.family())
# Handle quadrilateral short names like RTCF and RTCE.
element = reconstruct_element(element,
element.family(),
quad_opc,
element.degree())
element = element.reconstruct(cell=quad_tpc)
return FlattenToQuad(create_element(element, vector_is_mixed))

kind = element.variant()
if kind is None:
kind = 'equispaced' # default variant

if element.family() == "Lagrange":
if kind == 'equispaced':
lmbda = FIAT.Lagrange
elif kind == 'spectral' and element.cell().cellname() == 'interval':
lmbda = FIAT.GaussLobattoLegendre
else:
raise ValueError("Variant %r not supported on %s" % (kind, element.cell()))
elif element.family() == "Discontinuous Lagrange":
if kind == 'equispaced':
lmbda = FIAT.DiscontinuousLagrange
elif kind == 'spectral' and element.cell().cellname() == 'interval':
lmbda = FIAT.GaussLegendre
else:
raise ValueError("Variant %r not supported on %s" % (kind, element.cell()))
return lmbda(cell, element.degree())


Expand Down Expand Up @@ -267,7 +282,7 @@ def rec(eles):
return MixedElement(fiat_elements)


quad_opc = ufl.TensorProductCell(ufl.Cell("interval"), ufl.Cell("interval"))
quad_tpc = ufl.TensorProductCell(ufl.Cell("interval"), ufl.Cell("interval"))
_cache = weakref.WeakKeyDictionary()


Expand Down
26 changes: 21 additions & 5 deletions tsfc/finatinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from finat.fiat_elements import FiatElementBase

import ufl
from ufl.algorithms.elementtransformations import reconstruct_element

from tsfc.fiatinterface import as_fiat_cell
from tsfc.ufl_utils import spanning_degree
Expand Down Expand Up @@ -101,11 +100,28 @@ def convert_finiteelement(element):
raise ValueError("%s is supported, but handled incorrectly" %
element.family())
# Handle quadrilateral short names like RTCF and RTCE.
element = reconstruct_element(element,
element.family(),
quad_tpc,
element.degree())
element = element.reconstruct(cell=quad_tpc)
return finat.QuadrilateralElement(create_element(element))

kind = element.variant()
if kind is None:
kind = 'equispaced' # default variant

if element.family() == "Lagrange":
if kind == 'equispaced':
lmbda = finat.Lagrange
elif kind == 'spectral' and element.cell().cellname() == 'interval':
lmbda = finat.GaussLobattoLegendre
else:
raise ValueError("Variant %r not supported on %s" % (kind, element.cell()))
elif element.family() == "Discontinuous Lagrange":
kind = element.variant() or 'equispaced'
if kind == 'equispaced':
lmbda = finat.DiscontinuousLagrange
elif kind == 'spectral' and element.cell().cellname() == 'interval':
lmbda = finat.GaussLegendre
else:
raise ValueError("Variant %r not supported on %s" % (kind, element.cell()))
return lmbda(cell, element.degree())


Expand Down

0 comments on commit b870d95

Please sign in to comment.