Skip to content

Commit

Permalink
dialects: (csl-stencil) Add coefficients property
Browse files Browse the repository at this point in the history
  • Loading branch information
n-io committed Oct 17, 2024
1 parent 27a519b commit 338d30a
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 21 deletions.
102 changes: 101 additions & 1 deletion xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from collections.abc import Iterable, Sequence
from itertools import pairwise
from typing import cast
from typing import TypeAlias, cast

from xdsl.dialects import builtin, memref, stencil
from xdsl.dialects.builtin import (
AnyFloat,
AnyIntegerAttr,
AnyMemRefType,
AnyMemRefTypeConstr,
AnyTensorTypeConstr,
DenseArrayBase,
Float16Type,
Float32Type,
FloatData,
IndexType,
MemRefType,
TensorType,
)
from xdsl.dialects.csl import csl
from xdsl.dialects.experimental import dmp
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.ir import (
Expand Down Expand Up @@ -55,6 +61,36 @@
from xdsl.utils.isattr import isattr


def get_dir_and_distance(
offset: stencil.IndexAttr | tuple[int, ...],
) -> tuple[csl.Direction, int]:
"""
Given an access op, return the distance and direction, assuming as access
to a neighbour (not self) in a star-shape pattern
"""

if isinstance(offset, stencil.IndexAttr):
offset = tuple(offset)
assert len(offset) == 2, "Expecting 2-dimensional access"
assert (offset[0] == 0) != (
offset[1] == 0
), "Expecting neighbour access in a star-shape pattern"
if offset[0] < 0:
d = csl.Direction.EAST
elif offset[0] > 0:
d = csl.Direction.WEST
elif offset[1] < 0:
d = csl.Direction.NORTH
elif offset[1] > 0:
d = csl.Direction.SOUTH
else:
raise ValueError(
"Invalid offset, expecting 2-dimensional star-shape neighbor access"
)
max_distance = abs(max(offset, key=abs))
return d, max_distance


@irdl_attr_definition
class ExchangeDeclarationAttr(ParametrizedAttribute):
"""
Expand Down Expand Up @@ -152,6 +188,68 @@ def __init__(
)


CslFloat: TypeAlias = Float16Type | Float32Type


@irdl_attr_definition
class StencilCoeffsAttr(ParametrizedAttribute):
"""
This attribute represents coefficients for a stencil.
"""

name = "csl_stencil.coeffs"
north: ParameterDef[DenseArrayBase]
south: ParameterDef[DenseArrayBase]
east: ParameterDef[DenseArrayBase]
west: ParameterDef[DenseArrayBase]

def __init__(
self,
north: DenseArrayBase,
south: DenseArrayBase,
east: DenseArrayBase,
west: DenseArrayBase,
):
super().__init__([north, south, east, west])

@staticmethod
def get_empty(pattern: int, elem_t: CslFloat):
neighbours = [0] + (pattern - 1) * [1]
north = DenseArrayBase.create_dense_float(elem_t, neighbours)
south = DenseArrayBase.create_dense_float(elem_t, neighbours)
east = DenseArrayBase.create_dense_float(elem_t, neighbours)
west = DenseArrayBase.create_dense_float(elem_t, neighbours)
return StencilCoeffsAttr(north, south, east, west)

def clone_with(self, coeff: FloatData, offset: stencil.IndexAttr):
north = self.north
south = self.south
east = self.east
west = self.west

direction, dist = get_dir_and_distance(offset)
match direction:
case csl.Direction.NORTH:
north = self._rebuild_array_base(north, dist, coeff)
case csl.Direction.SOUTH:
south = self._rebuild_array_base(south, dist, coeff)
case csl.Direction.EAST:
east = self._rebuild_array_base(east, dist, coeff)
case csl.Direction.WEST:
west = self._rebuild_array_base(west, dist, coeff)

return StencilCoeffsAttr(north, south, east, west)

@staticmethod
def _rebuild_array_base(
data: DenseArrayBase, idx: int, value: FloatData
) -> DenseArrayBase:
assert isinstance(data.elt_type, AnyFloat)
lst = cast(list[FloatData], list(data.data))
lst[idx] = value
return DenseArrayBase.create_dense_float(data.elt_type, lst)


class ApplyOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
Expand Down Expand Up @@ -223,6 +321,8 @@ class ApplyOp(IRDLOperation):

bounds = opt_prop_def(stencil.StencilBoundsAttr)

coeffs = opt_prop_def(StencilCoeffsAttr)

res = var_result_def(stencil.StencilTypeConstr)

traits = frozenset(
Expand Down
22 changes: 2 additions & 20 deletions xdsl/transforms/lower_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,10 @@ def get_dir_and_distance_ops(
op: csl_stencil.AccessOp,
) -> tuple[csl.DirectionOp, arith.Constant]:
"""
Given an access op, return the distance and direction, assuming as access
Given an access op, return the distance and direction ops, assuming as access
to a neighbour (not self) in a star-shape pattern
"""

offset = tuple(op.offset)
assert len(offset) == 2, "Expecting 2-dimensional access"
assert (offset[0] == 0) != (
offset[1] == 0
), "Expecting neighbour access in a star-shape pattern"
if offset[0] < 0:
d = csl.Direction.EAST
elif offset[0] > 0:
d = csl.Direction.WEST
elif offset[1] < 0:
d = csl.Direction.NORTH
elif offset[1] > 0:
d = csl.Direction.SOUTH
else:
raise ValueError(
"Invalid offset, expecting 2-dimensional star-shape neighbor access"
)
max_distance = abs(max(offset, key=abs))
d, max_distance = csl_stencil.get_dir_and_distance(op.offset)
return csl.DirectionOp(d), arith.Constant(IntegerAttr(max_distance, 16))


Expand Down

0 comments on commit 338d30a

Please sign in to comment.