diff --git a/gem/flop_count.py b/gem/flop_count.py new file mode 100644 index 00000000..57577e34 --- /dev/null +++ b/gem/flop_count.py @@ -0,0 +1,197 @@ +""" +This file contains all the necessary functions to accurately count the +total number of floating point operations for a given script. +""" + +import gem.gem as gem +import gem.impero as imp +from functools import singledispatch +import numpy +import math + + +@singledispatch +def statement(tree, parameters): + raise NotImplementedError + + +@statement.register(imp.Block) +def statement_block(tree, parameters): + flops = sum(statement(child, parameters) for child in tree.children) + return flops + + +@statement.register(imp.For) +def statement_for(tree, parameters): + extent = tree.index.extent + assert extent is not None + child, = tree.children + flops = statement(child, parameters) + return flops * extent + + +@statement.register(imp.Initialise) +def statement_initialise(tree, parameters): + return 0 + + +@statement.register(imp.Accumulate) +def statement_accumulate(tree, parameters): + flops = expression_flops(tree.indexsum.children[0], parameters) + return flops + 1 + + +@statement.register(imp.Return) +def statement_return(tree, parameters): + flops = expression_flops(tree.expression, parameters) + return flops + 1 + + +@statement.register(imp.ReturnAccumulate) +def statement_returnaccumulate(tree, parameters): + flops = expression_flops(tree.indexsum.children[0], parameters) + return flops + 1 + + +@statement.register(imp.Evaluate) +def statement_evaluate(tree, parameters): + flops = expression_flops(tree.expression, parameters, top=True) + return flops + + +@singledispatch +def flops(expr, parameters): + raise NotImplementedError(f"Don't know how to count flops of {type(expr)}") + + +@flops.register(gem.Failure) +def flops_failure(expr, parameters): + raise ValueError("Not expecting a Failure node") + + +@flops.register(gem.Variable) +@flops.register(gem.Identity) +@flops.register(gem.Delta) +@flops.register(gem.Zero) +@flops.register(gem.Literal) +@flops.register(gem.Index) +@flops.register(gem.VariableIndex) +def flops_zero(expr, parameters): + # Initial set up of these Gem nodes are of 0 floating point operations. + return 0 + + +@flops.register(gem.LogicalNot) +@flops.register(gem.LogicalAnd) +@flops.register(gem.LogicalOr) +@flops.register(gem.ListTensor) +def flops_zeroplus(expr, parameters): + # These nodes contribute 0 floating point operations, but their children may not. + return 0 + sum(expression_flops(child, parameters) + for child in expr.children) + + +@flops.register(gem.Product) +def flops_product(expr, parameters): + # Multiplication by -1 is not a flop. + a, b = expr.children + if isinstance(a, gem.Literal) and a.value == -1: + return expression_flops(b, parameters) + elif isinstance(b, gem.Literal) and b.value == -1: + return expression_flops(a, parameters) + else: + return 1 + sum(expression_flops(child, parameters) + for child in expr.children) + + +@flops.register(gem.Sum) +@flops.register(gem.Division) +@flops.register(gem.Comparison) +@flops.register(gem.MathFunction) +@flops.register(gem.MinValue) +@flops.register(gem.MaxValue) +def flops_oneplus(expr, parameters): + return 1 + sum(expression_flops(child, parameters) + for child in expr.children) + + +@flops.register(gem.Power) +def flops_power(expr, parameters): + base, exponent = expr.children + base_flops = expression_flops(base, parameters) + if isinstance(exponent, gem.Literal): + exponent = exponent.value + if exponent > 0 and exponent == math.floor(exponent): + return base_flops + int(math.ceil(math.log2(exponent))) + else: + return base_flops + 5 # heuristic + else: + return base_flops + 5 # heuristic + + +@flops.register(gem.Conditional) +def flops_conditional(expr, parameters): + condition, then, else_ = (expression_flops(child, parameters) + for child in expr.children) + return condition + max(then, else_) + + +@flops.register(gem.Indexed) +@flops.register(gem.FlexiblyIndexed) +def flops_indexed(expr, parameters): + aggregate = sum(expression_flops(child, parameters) + for child in expr.children) + # Average flops per entry + return aggregate / numpy.product(expr.children[0].shape, dtype=int) + + +@flops.register(gem.IndexSum) +def flops_indexsum(expr, parameters): + raise ValueError("Not expecting IndexSum") + + +@flops.register(gem.Inverse) +def flops_inverse(expr, parameters): + n, _ = expr.shape + # 2n^3 + child flop count + return 2*n**3 + sum(expression_flops(child, parameters) + for child in expr.children) + + +@flops.register(gem.Solve) +def flops_solve(expr, parameters): + n, m = expr.shape + # 2mn + inversion cost of A + children flop count + return 2*n*m + 2*n**3 + sum(expression_flops(child, parameters) + for child in expr.children) + + +@flops.register(gem.ComponentTensor) +def flops_componenttensor(expr, parameters): + raise ValueError("Not expecting ComponentTensor") + + +def expression_flops(expression, parameters, top=False): + """An approximation to flops required for each expression. + + :arg expression: GEM expression. + :arg parameters: Useful miscellaneous information. + :arg top: are we at the root? + :returns: flop count for the expression + """ + if not top and expression in parameters.temporaries: + return 0 + else: + return flops(expression, parameters) + + +def count_flops(impero_c): + """An approximation to flops required for a scheduled impero_c tree. + + :arg impero_c: a :class:`~.Impero_C` object. + :returns: approximate flop count for the tree. + """ + try: + return statement(impero_c.tree, impero_c) + except (ValueError, NotImplementedError): + return 0 diff --git a/tests/test_flop_count.py b/tests/test_flop_count.py new file mode 100644 index 00000000..1925fbaf --- /dev/null +++ b/tests/test_flop_count.py @@ -0,0 +1,64 @@ +import pytest +import gem.gem as gem +from gem.flop_count import count_flops +from gem.impero_utils import preprocess_gem +from gem.impero_utils import compile_gem + + +def test_count_flops(expression): + expr, expected = expression + flops = count_flops(expr) + assert flops == expected + + +@pytest.fixture(params=("expr1", "expr2", "expr3", "expr4")) +def expression(request): + if request.param == "expr1": + expr = gem.Sum(gem.Product(gem.Variable("a", ()), gem.Literal(2)), + gem.Division(gem.Literal(3), gem.Variable("b", ()))) + C = gem.Variable("C", (1,)) + i, = gem.indices(1) + Ci = C[i] + expr, = preprocess_gem([expr]) + assignments = [(Ci, expr)] + expr = compile_gem(assignments, (i,)) + # C += a*2 + 3/b + expected = 1 + 3 + elif request.param == "expr2": + expr = gem.Comparison(">=", gem.MaxValue(gem.Literal(1), gem.Literal(2)), + gem.MinValue(gem.Literal(3), gem.Literal(1))) + C = gem.Variable("C", (1,)) + i, = gem.indices(1) + Ci = C[i] + expr, = preprocess_gem([expr]) + assignments = [(Ci, expr)] + expr = compile_gem(assignments, (i,)) + # C += max(1, 2) >= min(3, 1) + expected = 1 + 3 + elif request.param == "expr3": + expr = gem.Solve(gem.Identity(3), gem.Inverse(gem.Identity(3))) + C = gem.Variable("C", (3, 3)) + i, j = gem.indices(2) + Cij = C[i, j] + expr, = preprocess_gem([expr[i, j]]) + assignments = [(Cij, expr)] + expr = compile_gem(assignments, (i, j)) + # C += solve(Id(3x3), Id(3x3)^{-1}) + expected = 9 + 18 + 54 + 54 + elif request.param == "expr4": + A = gem.Variable("A", (10, 15)) + B = gem.Variable("B", (8, 10)) + i, j, k = gem.indices(3) + Aij = A[i, j] + Bki = B[k, i] + Cjk = gem.IndexSum(Aij * Bki, (i,)) + expr = Cjk + expr, = preprocess_gem([expr]) + assignments = [(gem.Variable("C", (15, 8))[j, k], expr)] + expr = compile_gem(assignments, (j, k)) + # Cjk += \sum_i Aij * Bki + expected = 2 * 10 * 8 * 15 + + else: + raise ValueError("Unexpected expression") + return expr, expected diff --git a/tests/test_impero_loopy_flop_counts.py b/tests/test_impero_loopy_flop_counts.py new file mode 100644 index 00000000..db393203 --- /dev/null +++ b/tests/test_impero_loopy_flop_counts.py @@ -0,0 +1,66 @@ +""" +Tests impero flop counts against loopy. +""" +import pytest +import numpy +import loopy +from tsfc import compile_form +from ufl import (FiniteElement, FunctionSpace, Mesh, TestFunction, + TrialFunction, VectorElement, dx, grad, inner, + interval, triangle, quadrilateral, + TensorProductCell) + + +def count_loopy_flops(kernel): + name = kernel.name + program = kernel.ast + program = program.with_kernel( + program[name].copy( + target=loopy.CTarget(), + silenced_warnings=["insn_count_subgroups_upper_bound", + "get_x_map_guessing_subgroup_size"]) + ) + op_map = loopy.get_op_map(program + .with_entrypoints(kernel.name), + numpy_types=None, + subgroup_size=1) + return op_map.filter_by(name=['add', 'sub', 'mul', 'div', + 'func:abs'], + dtype=[float]).eval_and_sum({}) + + +@pytest.fixture(params=[interval, triangle, quadrilateral, + TensorProductCell(triangle, interval)], + ids=lambda cell: cell.cellname()) +def cell(request): + return request.param + + +@pytest.fixture(params=[{"mode": "vanilla"}, + {"mode": "spectral"}], + ids=["vanilla", "spectral"]) +def parameters(request): + return request.param + + +def test_flop_count(cell, parameters): + mesh = Mesh(VectorElement("P", cell, 1)) + loopy_flops = [] + new_flops = [] + for k in range(1, 5): + V = FunctionSpace(mesh, FiniteElement("P", cell, k)) + u = TrialFunction(V) + v = TestFunction(V) + a = inner(u, v)*dx + inner(grad(u), grad(v))*dx + kernel, = compile_form(a, prefix="form", + parameters=parameters, + coffee=False) + # Record new flops here, and compare asymptotics and + # approximate order of magnitude. + new_flops.append(kernel.flop_count) + loopy_flops.append(count_loopy_flops(kernel)) + + new_flops = numpy.asarray(new_flops) + loopy_flops = numpy.asarray(loopy_flops) + + assert all(new_flops == loopy_flops) diff --git a/tests/test_sum_factorisation.py b/tests/test_sum_factorisation.py index 9e31b125..cb6a992e 100644 --- a/tests/test_sum_factorisation.py +++ b/tests/test_sum_factorisation.py @@ -1,8 +1,6 @@ import numpy import pytest -from coffee.visitors import EstimateFlops - from ufl import (Mesh, FunctionSpace, FiniteElement, VectorElement, TestFunction, TrialFunction, TensorProductCell, EnrichedElement, HCurlElement, HDivElement, @@ -68,7 +66,8 @@ def split_vector_laplace(cell, degree): def count_flops(form): kernel, = compile_form(form, parameters=dict(mode='spectral')) - return EstimateFlops().visit(kernel.ast) + flops = kernel.flop_count + return flops @pytest.mark.parametrize(('cell', 'order'), diff --git a/tsfc/driver.py b/tsfc/driver.py index b0e46af6..a5928b70 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -17,6 +17,7 @@ import gem import gem.impero_utils as impero_utils +from gem.flop_count import count_flops import FIAT from FIAT.reference_element import TensorProductCell @@ -240,6 +241,7 @@ def compile_integral(integral_data, form_data, prefix, parameters, interface, co index_ordering = tuple(quadrature_indices) + split_argument_indices try: impero_c = impero_utils.compile_gem(assignments, index_ordering, remove_zeros=True) + flop_count = count_flops(impero_c) except impero_utils.NoopError: # No operations, construct empty kernel return builder.construct_empty_kernel(kernel_name) @@ -265,7 +267,7 @@ def name_multiindex(multiindex, name): for multiindex, name in zip(argument_multiindices, ['j', 'k']): name_multiindex(multiindex, name) - return builder.construct_kernel(kernel_name, impero_c, index_names, quad_rule) + return builder.construct_kernel(kernel_name, impero_c, index_names, quad_rule, flop_count=flop_count) def compile_expression_dual_evaluation(expression, to_element, *, diff --git a/tsfc/kernel_interface/firedrake.py b/tsfc/kernel_interface/firedrake.py index 63ada6ad..faed2803 100644 --- a/tsfc/kernel_interface/firedrake.py +++ b/tsfc/kernel_interface/firedrake.py @@ -27,7 +27,8 @@ def make_builder(*args, **kwargs): class Kernel(object): __slots__ = ("ast", "integral_type", "oriented", "subdomain_id", "domain_number", "needs_cell_sizes", "tabulations", "quadrature_rule", - "coefficient_numbers", "name", "__weakref__") + "coefficient_numbers", "name", "__weakref__", + "flop_count") """A compiled Kernel object. :kwarg ast: The COFFEE ast for the kernel. @@ -42,11 +43,13 @@ class Kernel(object): :kwarg quadrature_rule: The finat quadrature rule used to generate this kernel :kwarg tabulations: The runtime tabulations this kernel requires :kwarg needs_cell_sizes: Does the kernel require cell sizes. + :kwarg flop_count: Estimated total flops for this kernel. """ def __init__(self, ast=None, integral_type=None, oriented=False, subdomain_id=None, domain_number=None, quadrature_rule=None, coefficient_numbers=(), - needs_cell_sizes=False): + needs_cell_sizes=False, + flop_count=0): # Defaults self.ast = ast self.integral_type = integral_type @@ -55,6 +58,7 @@ def __init__(self, ast=None, integral_type=None, oriented=False, self.subdomain_id = subdomain_id self.coefficient_numbers = coefficient_numbers self.needs_cell_sizes = needs_cell_sizes + self.flop_count = flop_count super(Kernel, self).__init__() @@ -267,7 +271,8 @@ def register_requirements(self, ir): knl = self.kernel knl.oriented, knl.needs_cell_sizes, knl.tabulations = check_requirements(ir) - def construct_kernel(self, name, impero_c, index_names, quadrature_rule): + def construct_kernel(self, name, impero_c, index_names, quadrature_rule, + flop_count=0): """Construct a fully built :class:`Kernel`. This function contains the logic for building the argument @@ -277,6 +282,7 @@ def construct_kernel(self, name, impero_c, index_names, quadrature_rule): :arg impero_c: ImperoC tuple with Impero AST and other data :arg index_names: pre-assigned index names :arg quadrature rule: quadrature rule + :arg flop_count: Estimated total flops for this kernel. :returns: :class:`Kernel` object """ @@ -304,6 +310,7 @@ def construct_kernel(self, name, impero_c, index_names, quadrature_rule): self.kernel.quadrature_rule = quadrature_rule self.kernel.name = name self.kernel.ast = KernelBuilderBase.construct_kernel(self, name, args, body) + self.kernel.flop_count = flop_count return self.kernel def construct_empty_kernel(self, name): @@ -351,8 +358,8 @@ def prepare_coefficient(coefficient, name, scalar_type, interior_facet=False): funarg = coffee.Decl(scalar_type, coffee.Symbol(name), pointers=[("restrict",)], qualifiers=["const"]) - - expression = gem.reshape(gem.Variable(name, (None,)), + value_size = coefficient.ufl_element().value_size() + expression = gem.reshape(gem.Variable(name, (value_size,)), coefficient.ufl_shape) return funarg, expression diff --git a/tsfc/kernel_interface/firedrake_loopy.py b/tsfc/kernel_interface/firedrake_loopy.py index 1304978b..1a7bda9c 100644 --- a/tsfc/kernel_interface/firedrake_loopy.py +++ b/tsfc/kernel_interface/firedrake_loopy.py @@ -28,7 +28,8 @@ def make_builder(*args, **kwargs): class Kernel(object): __slots__ = ("ast", "integral_type", "oriented", "subdomain_id", "domain_number", "needs_cell_sizes", "tabulations", "quadrature_rule", - "coefficient_numbers", "name", "__weakref__") + "coefficient_numbers", "name", "flop_count", + "__weakref__") """A compiled Kernel object. :kwarg ast: The loopy kernel object. @@ -44,11 +45,13 @@ class Kernel(object): :kwarg tabulations: The runtime tabulations this kernel requires :kwarg needs_cell_sizes: Does the kernel require cell sizes. :kwarg name: The name of this kernel. + :kwarg flop_count: Estimated total flops for this kernel. """ def __init__(self, ast=None, integral_type=None, oriented=False, subdomain_id=None, domain_number=None, quadrature_rule=None, coefficient_numbers=(), - needs_cell_sizes=False): + needs_cell_sizes=False, + flop_count=0): # Defaults self.ast = ast self.integral_type = integral_type @@ -57,6 +60,7 @@ def __init__(self, ast=None, integral_type=None, oriented=False, self.subdomain_id = subdomain_id self.coefficient_numbers = coefficient_numbers self.needs_cell_sizes = needs_cell_sizes + self.flop_count = flop_count super(Kernel, self).__init__() @@ -275,7 +279,8 @@ def register_requirements(self, ir): knl = self.kernel knl.oriented, knl.needs_cell_sizes, knl.tabulations = check_requirements(ir) - def construct_kernel(self, name, impero_c, index_names, quadrature_rule): + def construct_kernel(self, name, impero_c, index_names, quadrature_rule, + flop_count=0): """Construct a fully built :class:`Kernel`. This function contains the logic for building the argument @@ -285,6 +290,7 @@ def construct_kernel(self, name, impero_c, index_names, quadrature_rule): :arg impero_c: ImperoC tuple with Impero AST and other data :arg index_names: pre-assigned index names :arg quadrature rule: quadrature rule + :arg flop_count: Estimated total flops for this kernel. :returns: :class:`Kernel` object """ @@ -305,6 +311,7 @@ def construct_kernel(self, name, impero_c, index_names, quadrature_rule): self.kernel.quadrature_rule = quadrature_rule self.kernel.ast = generate_loopy(impero_c, args, self.scalar_type, name, index_names) self.kernel.name = name + self.kernel.flop_count = flop_count return self.kernel def construct_empty_kernel(self, name): @@ -332,8 +339,9 @@ def prepare_coefficient(coefficient, name, scalar_type, interior_facet=False): if coefficient.ufl_element().family() == 'Real': # Constant - funarg = lp.GlobalArg(name, dtype=scalar_type, shape=(coefficient.ufl_element().value_size(),)) - expression = gem.reshape(gem.Variable(name, (None,)), + value_size = coefficient.ufl_element().value_size() + funarg = lp.GlobalArg(name, dtype=scalar_type, shape=(value_size,)) + expression = gem.reshape(gem.Variable(name, (value_size,)), coefficient.ufl_shape) return funarg, expression