Skip to content

Commit

Permalink
Cache basis evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Dec 11, 2024
1 parent 6a20368 commit aab3455
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
11 changes: 9 additions & 2 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import collections
import itertools
from functools import singledispatch
from functools import cached_property, singledispatch

import gem
import numpy
Expand All @@ -18,7 +18,6 @@
from gem.node import traversal
from gem.optimise import constant_fold_zero, ffc_rounding
from gem.unconcatenate import unconcatenate
from gem.utils import cached_property
from ufl.classes import (Argument, CellCoordinate, CellEdgeVectors,
CellFacetJacobian, CellOrientation, CellOrigin,
CellVertices, CellVolume, Coefficient, FacetArea,
Expand All @@ -42,6 +41,8 @@
TSFCConstantMixin, entity_avg, one_times,
preprocess_expression, simplify_abs)

from pyop2.caching import serial_cache


class ContextBase(ProxyKernelInterface):
"""Common UFL -> GEM translation context."""
Expand Down Expand Up @@ -296,6 +297,12 @@ def point_expr(self):
def weight_expr(self):
return self.quadrature_rule.weight_expression

def make_basis_evaluation_key(self, finat_element, mt, entity_id):
domain = extract_unique_domain(mt.terminal)
restriction = mt.restriction
return (self, finat_element, mt.local_derivatives, domain, restriction, entity_id)

@serial_cache(hashkey=make_basis_evaluation_key)
def basis_evaluation(self, finat_element, mt, entity_id):
return finat_element.basis_evaluation(mt.local_derivatives,
self.point_set,
Expand Down
3 changes: 1 addition & 2 deletions tsfc/kernel_interface/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections
import operator
import string
from functools import reduce
from functools import cached_property, reduce
from itertools import chain, product

import gem
Expand All @@ -13,7 +13,6 @@
from gem.node import traversal
from gem.optimise import constant_fold_zero
from gem.optimise import remove_componenttensors as prune
from gem.utils import cached_property
from numpy import asarray
from tsfc import fem, ufl_utils
from tsfc.finatinterface import as_fiat_cell, create_element
Expand Down

0 comments on commit aab3455

Please sign in to comment.