Skip to content

Commit

Permalink
dialects: (builtin) make DenseIntOrFPElementsAttr a TypedAttribute (#…
Browse files Browse the repository at this point in the history
…3509)

A simpler version of #3492 which does not make
`DenseIntOrFPElementsAttr` generic.
  • Loading branch information
alexarice authored Nov 22, 2024
1 parent 4c530f8 commit 018d17f
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 59 deletions.
60 changes: 57 additions & 3 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
SymbolTable,
)
from xdsl.utils.exceptions import DiagnosticException, VerifyException
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr

if TYPE_CHECKING:
Expand Down Expand Up @@ -1715,11 +1716,11 @@ def get_element_type(self) -> _UnrankedMemrefTypeElems:
VectorType[AttributeCovT] | TensorType[AttributeCovT] | MemRefType[AttributeCovT]
)

AnyDenseElement: TypeAlias = IntegerType | IndexType | AnyFloat


@irdl_attr_definition
class DenseIntOrFPElementsAttr(
ParametrizedAttribute, ContainerType[IntegerType | IndexType | AnyFloat]
):
class DenseIntOrFPElementsAttr(TypedAttribute, ContainerType[AnyDenseElement]):
name = "dense"
type: ParameterDef[
RankedStructure[IntegerType]
Expand Down Expand Up @@ -1871,6 +1872,59 @@ def tensor_from_list(
t = TensorType(data_type, shape)
return DenseIntOrFPElementsAttr.from_list(t, data)

@staticmethod
def parse_with_type(parser: AttrParser, type: Attribute) -> TypedAttribute:
assert isa(type, RankedStructure[AnyDenseElement])
return parser.parse_dense_int_or_fp_elements_attr(type)

@staticmethod
def _print_one_elem(val: Attribute, printer: Printer):
if isinstance(val, IntegerAttr):
printer.print_string(f"{val.value.data}")
elif isinstance(val, FloatAttr):
printer.print_float(cast(AnyFloatAttr, val))
else:
raise Exception(
"unexpected attribute type "
"in DenseIntOrFPElementsAttr: "
f"{type(val)}"
)

@staticmethod
def _print_dense_list(
array: Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr],
shape: Sequence[int],
printer: Printer,
):
printer.print_string("[")
if len(shape) > 1:
k = len(array) // shape[0]
printer.print_list(
(array[i : i + k] for i in range(0, len(array), k)),
lambda subarray: DenseIntOrFPElementsAttr._print_dense_list(
subarray, shape[1:], printer
),
)
else:
printer.print_list(
array,
lambda val: DenseIntOrFPElementsAttr._print_one_elem(val, printer),
)
printer.print_string("]")

def print_without_type(self, printer: Printer):
printer.print_string("dense<")
data = self.data.data
shape = self.get_shape() if self.shape_is_complete else (len(data),)
assert shape is not None, "If shape is complete, then it cannot be None"
if len(data) == 0:
pass
elif data.count(data[0]) == len(data):
DenseIntOrFPElementsAttr._print_one_elem(data[0], printer)
else:
DenseIntOrFPElementsAttr._print_dense_list(data, shape, printer)
printer.print_string(">")


Builtin = Dialect(
"builtin",
Expand Down
19 changes: 11 additions & 8 deletions xdsl/parser/attribute_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
AffineMapAttr,
AffineSetAttr,
AnyArrayAttr,
AnyDenseElement,
AnyFloat,
AnyFloatAttr,
AnyFloatConstr,
Expand Down Expand Up @@ -698,11 +699,7 @@ def _parse_optional_builtin_parametrized_attr(self) -> Attribute | None:
def _parse_builtin_dense_attr_hex(
self,
hex_string: str,
type: (
RankedStructure[IntegerType]
| RankedStructure[IndexType]
| RankedStructure[AnyFloat]
),
type: RankedStructure[AnyDenseElement],
) -> tuple[list[int] | list[float], list[int]]:
"""
Parse a hex string literal e.g. dense<"0x82F5AB00">, and returns its flattened data
Expand Down Expand Up @@ -795,7 +792,9 @@ def _parse_dense_literal_type(
self.raise_error("Dense literal attribute should have a static shape.")
return type

def _parse_builtin_dense_attr(self, _name: Span) -> DenseIntOrFPElementsAttr:
def parse_dense_int_or_fp_elements_attr(
self, type: RankedStructure[AnyDenseElement] | None
) -> DenseIntOrFPElementsAttr:
dense_contents: (
tuple[list[AttrParser._TensorLiteralElement], list[int]] | str | None
)
Expand All @@ -821,8 +820,9 @@ def _parse_builtin_dense_attr(self, _name: Span) -> DenseIntOrFPElementsAttr:
self.parse_punctuation(">", " in dense attribute")

# Parse the dense type and check for correctness
self.parse_punctuation(":", " in dense attribute")
type = self._parse_dense_literal_type()
if type is None:
self.parse_punctuation(":", " in dense attribute")
type = self._parse_dense_literal_type()
type_shape = list(type.get_shape())
type_num_values = math.prod(type_shape)

Expand Down Expand Up @@ -866,6 +866,9 @@ def _parse_builtin_dense_attr(self, _name: Span) -> DenseIntOrFPElementsAttr:

return DenseIntOrFPElementsAttr.from_list(type, data_values)

def _parse_builtin_dense_attr(self, _name: Span) -> DenseIntOrFPElementsAttr:
return self.parse_dense_int_or_fp_elements_attr(None)

def _parse_builtin_opaque_attr(self, _name: Span):
str_lit_list = self.parse_comma_separated_list(
self.Delimiter.ANGLE, self.parse_str_literal
Expand Down
48 changes: 0 additions & 48 deletions xdsl/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
AffineMapAttr,
AffineSetAttr,
AnyFloatAttr,
AnyIntegerAttr,
AnyUnrankedMemrefType,
AnyUnrankedTensorType,
AnyVectorType,
Expand All @@ -23,15 +22,13 @@
BytesAttr,
ComplexType,
DenseArrayBase,
DenseIntOrFPElementsAttr,
DenseResourceAttr,
DictionaryAttr,
Float16Type,
Float32Type,
Float64Type,
Float80Type,
Float128Type,
FloatAttr,
FloatData,
FunctionType,
IndexType,
Expand Down Expand Up @@ -612,51 +609,6 @@ def print_attribute(self, attribute: Attribute) -> None:
self.print_string(")")
return

if isinstance(attribute, DenseIntOrFPElementsAttr):

def print_one_elem(val: Attribute):
if isinstance(val, IntegerAttr):
self.print_string(f"{val.value.data}")
elif isinstance(val, FloatAttr):
self.print_float(cast(AnyFloatAttr, val))
else:
raise Exception(
"unexpected attribute type "
"in DenseIntOrFPElementsAttr: "
f"{type(val)}"
)

def print_dense_list(
array: Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr],
shape: Sequence[int],
):
self.print_string("[")
if len(shape) > 1:
k = len(array) // shape[0]
self.print_list(
(array[i : i + k] for i in range(0, len(array), k)),
lambda subarray: print_dense_list(subarray, shape[1:]),
)
else:
self.print_list(array, print_one_elem)
self.print_string("]")

self.print_string("dense<")
data = attribute.data.data
shape = (
attribute.get_shape() if attribute.shape_is_complete else (len(data),)
)
assert shape is not None, "If shape is complete, then it cannot be None"
if len(data) == 0:
pass
elif data.count(data[0]) == len(data):
print_one_elem(data[0])
else:
print_dense_list(data, shape)
self.print_string("> : ")
self.print_attribute(attribute.type)
return

if isinstance(attribute, DenseResourceAttr):
handle = attribute.resource_handle.data
self.print_string(f"dense_resource<{handle}> : ")
Expand Down

0 comments on commit 018d17f

Please sign in to comment.