Skip to content

Commit

Permalink
core: with-type-constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Nov 19, 2024
1 parent badc71b commit 9fde882
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 61 deletions.
34 changes: 17 additions & 17 deletions tests/filecheck/dialects/fsm/fsm_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@
"fsm.transition"() ({
"fsm.return"(%arg0) : (i1) -> ()
}, {

}) {nextState = @A} : () -> ()

}) {sym_name = "A"} : () -> ()
}) {function_type = (i1) -> (i1), initialState = "A", sym_name = "foo", res_names = ["names"],res_attrs = [{"name"="1","type"="2"}] } : () -> ()

Expand All @@ -160,7 +160,7 @@
"fsm.output"() : () -> ()
}, {
"fsm.transition"() ({
^bb1(%arg3: i1):
^bb1(%arg3: i1):
"fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> ()
"fsm.output"() : () -> ()
}, {
Expand All @@ -179,16 +179,16 @@

"fsm.state"() ({
"fsm.output"() : () -> ()

}, {
"fsm.transition"() ({

}, {
^bb1(%arg3: i1):
^bb1(%arg3: i1):
"fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> ()
"fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> ()
}) {nextState = @A} : () -> ()

}) {sym_name = "A"} : () -> ()
}) {function_type = () -> (), initialState = "A", sym_name = "foo", res_names = ["names"],res_attrs = [{"name"="1","type"="2"}] } : () -> ()

Expand All @@ -203,10 +203,10 @@

}, {
"fsm.transition"() ({

}, {
}) {nextState = @A} : () -> ()

}) {sym_name = "A"} : () -> ()

}) {function_type = (i16) -> (i16) , initialState = "A", sym_name = "foo"} : () -> ()
Expand Down Expand Up @@ -257,7 +257,7 @@
}, {
}) {nextState = @A} : () -> ()
}) {sym_name = "A"} : () -> ()

}) {function_type = (i16) -> (i1), initialState = "A", sym_name = "foo"} : () -> ()
%arg1 = "arith.constant"() {value = 0 : i16} : () -> i16
%arg2 = "arith.constant"() {value = 0 : i16} : () -> i16
Expand Down Expand Up @@ -328,12 +328,12 @@
}) {nextState = @C} : () -> ()
}) {sym_name = "C"} : () -> ()
}) {function_type = (i16) -> (i16), initialState = "A", sym_name = "foo"} : () -> ()

"func.func"() ({
%3 = "arith.constant"() {value = 16: i16} : () -> i16

%4 = "fsm.instance"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype
%1 = "arith.constant"() {value = true} : () -> i16
%1 = "arith.constant"() {value = 0 : i16} : () -> i16
%2 = "fsm.trigger"(%1, %4) : (i16, !fsm.instancetype) -> i1
"func.return"() : () -> ()
}) {function_type = () -> (), sym_name = "qux"} : () -> ()
Expand Down Expand Up @@ -371,10 +371,10 @@
}) {nextState = @C} : () -> ()
}) {sym_name = "C"} : () -> ()
}) {function_type = (i16) -> (i16), initialState = "A", sym_name = "foo"} : () -> ()

"func.func"() ({
%3 = "arith.constant"() {value = 16: i16} : () -> i16

%4 = "fsm.instance"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype
%1 = "arith.constant"() {value = true} : () -> i1
%2 = "fsm.trigger"(%1, %4) : (i1, !fsm.instancetype) -> i16
Expand All @@ -391,8 +391,8 @@
%0 = "fsm.variable"() {initValue = 0 : i16, name = "cnt"} : () -> i16
"fsm.machine"() ({
%4 = "test.op"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype
%1 = "arith.constant"() {value = true} : () -> i16
%2 = "fsm.trigger"(%1, %4) : (i16, !fsm.instancetype) -> i1
%1 = "arith.constant"() {value = true} : () -> i1
%2 = "fsm.trigger"(%1, %4) : (i1, !fsm.instancetype) -> i1
"func.return"() : () -> ()
}) {function_type = () -> (), sym_name = "qux"} : () -> ()

Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/runner/factorial.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ builtin.module {
"func.return"(%ret) : (i64) -> ()
}
func.func @main() -> index {
%zero = "arith.constant"() {"value" = 0} : () -> index
%zero = "arith.constant"() {"value" = 0 : index} : () -> index
%i = "arith.constant"() {"value" = 12} : () -> i64
%fac = "func.call"(%i) {"callee" = @factorial} : (i64) -> i64
printf.print_format "factorial({})={}", %i : i64, %fac : i64
Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/runner/with-wgpu/global_id_inc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ builtin.module attributes {gpu.container_module} {
%hmemref = "memref.alloc"() {"alignment" = 0 : i64, "operandSegmentSizes" = array<i32: 0, 0>} : () -> memref<4x4xindex>
"gpu.memcpy"(%hmemref, %memref) {"operandSegmentSizes" = array<i32: 0, 1, 1>} : (memref<4x4xindex>, memref<4x4xindex>) -> ()
printf.print_format "Result : {}", %hmemref : memref<4x4xindex>
%zero = "arith.constant"() {"value" = 0} : () -> (index)
%zero = "arith.constant"() {"value" = 0 : index} : () -> (index)
"func.return"(%zero) : (index) -> ()
}
}
Expand Down
24 changes: 12 additions & 12 deletions tests/filecheck/transforms/function-constant-pinning.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


func.func @basic() -> i32 {
%v = "test.op"() {pin_to_constants = [0]} : () -> i32
%v = "test.op"() {pin_to_constants = [0 : i32]} : () -> i32
func.return %v : i32
}

Expand All @@ -11,7 +11,7 @@ func.func @basic() -> i32 {
// CHECK-NEXT: func.func @basic() -> i32 {
// CHECK-NEXT: %v = "test.op"() : () -> i32
// compare the value to the constant we want to specialize for
// CHECK-NEXT: %0 = arith.constant 0 : i64
// CHECK-NEXT: %0 = arith.constant 0 : i32
// CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32
// CHECK-NEXT: %2 = scf.if %1 -> (i32) {
// if they are equal, branch to specialized function
Expand All @@ -25,7 +25,7 @@ func.func @basic() -> i32 {
// specialized function here
// CHECK-NEXT: func.func @basic_pinned() -> i32 {
// original op is replaced by constant instantiation
// CHECK-NEXT: %v = arith.constant 0 : i64
// CHECK-NEXT: %v = arith.constant 0 : i32
// CHECK-NEXT: func.return %v : i32
// CHECK-NEXT: }

Expand Down Expand Up @@ -79,7 +79,7 @@ func.func @control_flow() {


func.func @function_args(%arg0: memref<100xf32>) -> i32 {
%v = "test.op"() {pin_to_constants = [0]} : () -> i32
%v = "test.op"() {pin_to_constants = [0 : i32]} : () -> i32

"test.op"(%v, %arg0) : (i32, memref<100xf32>) -> ()

Expand All @@ -89,7 +89,7 @@ func.func @function_args(%arg0: memref<100xf32>) -> i32 {

// CHECK-NEXT: func.func @function_args(%arg0 : memref<100xf32>) -> i32 {
// CHECK-NEXT: %v = "test.op"() : () -> i32
// CHECK-NEXT: %0 = arith.constant 0 : i64
// CHECK-NEXT: %0 = arith.constant 0 : i32
// CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32
// CHECK-NEXT: %2 = scf.if %1 -> (i32) {
// make sure that we forward function args to the specialized function
Expand All @@ -103,7 +103,7 @@ func.func @function_args(%arg0: memref<100xf32>) -> i32 {
// CHECK-NEXT: func.return %2 : i32
// CHECK-NEXT: }
// CHECK-NEXT: func.func @function_args_pinned(%arg0 : memref<100xf32>) -> i32 {
// CHECK-NEXT: %v = arith.constant 0 : i64
// CHECK-NEXT: %v = arith.constant 0 : i32
// here the function arg is used
// CHECK-NEXT: "test.op"(%v, %arg0) : (i32, memref<100xf32>) -> ()
// CHECK-NEXT: func.return %v : i32
Expand Down Expand Up @@ -155,7 +155,7 @@ func.func @control_flow_and_function_args(%arg: i32) -> i32 {


func.func @specialize_multi_case() -> i32 {
%v = "test.op"() {pin_to_constants = [0, 1]} : () -> i32
%v = "test.op"() {pin_to_constants = [0 : i32, 1 : i32]} : () -> i32
func.return %v : i32
}

Expand All @@ -164,13 +164,13 @@ func.func @specialize_multi_case() -> i32 {

// CHECK-NEXT: func.func @specialize_multi_case() -> i32 {
// CHECK-NEXT: %v = "test.op"() : () -> i32
// CHECK-NEXT: %0 = arith.constant 0 : i64
// CHECK-NEXT: %0 = arith.constant 0 : i32
// CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32
// CHECK-NEXT: %2 = scf.if %1 -> (i32) {
// CHECK-NEXT: %3 = func.call @specialize_multi_case_pinned_1() : () -> i32
// CHECK-NEXT: scf.yield %3 : i32
// CHECK-NEXT: } else {
// CHECK-NEXT: %4 = arith.constant 1 : i64
// CHECK-NEXT: %4 = arith.constant 1 : i32
// CHECK-NEXT: %5 = arith.cmpi eq, %v, %4 : i32
// CHECK-NEXT: %6 = scf.if %5 -> (i32) {
// CHECK-NEXT: %7 = func.call @specialize_multi_case_pinned() : () -> i32
Expand All @@ -185,8 +185,8 @@ func.func @specialize_multi_case() -> i32 {
// CHECK-NEXT: func.func @specialize_multi_case_pinned_1() -> i32 {
// this function still carries the old specialization check within it, but MLIR can see that
// the branch is never taken, so it's completely removed.
// CHECK-NEXT: %v = arith.constant 0 : i64
// CHECK-NEXT: %0 = arith.constant 1 : i64
// CHECK-NEXT: %v = arith.constant 0 : i32
// CHECK-NEXT: %0 = arith.constant 1 : i32
// CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32
// CHECK-NEXT: %2 = scf.if %1 -> (i32) {
// CHECK-NEXT: %3 = func.call @specialize_multi_case_pinned() : () -> i32
Expand All @@ -197,7 +197,7 @@ func.func @specialize_multi_case() -> i32 {
// CHECK-NEXT: func.return %2 : i32
// CHECK-NEXT: }
// CHECK-NEXT: func.func @specialize_multi_case_pinned() -> i32 {
// CHECK-NEXT: %v = arith.constant 1 : i64
// CHECK-NEXT: %v = arith.constant 1 : i32
// CHECK-NEXT: func.return %v : i32
// CHECK-NEXT: }

Expand Down
44 changes: 16 additions & 28 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
AnyFloat,
AnyFloatConstr,
AnyIntegerAttr,
AnyIntegerAttrConstr,
ContainerOf,
DenseIntOrFPElementsAttr,
Float16Type,
Expand All @@ -25,9 +26,12 @@
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.ir import Attribute, BitEnumAttribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AnyAttr,
AnyOf,
BaseAttr,
IRDLOperation,
VarConstraint,
WithTypeConstraint,
base,
irdl_attr_definition,
irdl_op_definition,
Expand All @@ -48,7 +52,6 @@
Pure,
)
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.isattr import isattr
from xdsl.utils.str_enum import StrEnum

boolLike = ContainerOf(IntegerType(1))
Expand Down Expand Up @@ -124,11 +127,21 @@ def __init__(self, flags: None | Sequence[IntegerOverflowFlag] | Literal["none"]
@irdl_op_definition
class Constant(IRDLOperation):
name = "arith.constant"
result = result_def(Attribute)
value = prop_def(Attribute)
_T: ClassVar = VarConstraint("T", AnyAttr())
result = result_def(_T)
value = prop_def(
WithTypeConstraint(
AnyIntegerAttrConstr
| BaseAttr[FloatAttr[AnyFloat]](FloatAttr)
| BaseAttr(DenseIntOrFPElementsAttr),
_T,
)
)

traits = traits_def(ConstantLike(), Pure())

assembly_format = "attr-dict $value"

@overload
def __init__(
self,
Expand Down Expand Up @@ -162,31 +175,6 @@ def from_int_and_width(
properties={"value": IntegerAttr(value, value_type)},
)

def print(self, printer: Printer):
printer.print_op_attributes(self.attributes)

printer.print(" ")
printer.print_attribute(self.value)

@classmethod
def parse(cls: type[Constant], parser: Parser) -> Constant:
attrs = parser.parse_optional_attr_dict()

p0 = parser.pos
value = parser.parse_attribute()

if not isattr(
value,
base(AnyIntegerAttr)
| base(FloatAttr[AnyFloat])
| base(DenseIntOrFPElementsAttr),
):
parser.raise_error("Invalid constant value", p0, parser.pos)

c = Constant(value)
c.attributes.update(attrs)
return c


_T = TypeVar("_T", bound=Attribute)

Expand Down
15 changes: 13 additions & 2 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
ParamAttrConstraint,
ParameterDef,
VarExtractor,
WithType,
attr_constr_coercion,
base,
irdl_attr_definition,
Expand Down Expand Up @@ -450,6 +451,7 @@ class IndexType(ParametrizedAttribute):
class IntegerAttr(
Generic[_IntegerAttrType],
TypedAttribute[_IntegerAttrType],
WithType,
):
name = "integer"
value: ParameterDef[IntAttr]
Expand Down Expand Up @@ -516,6 +518,9 @@ def parse_with_type(
def print_without_type(self, printer: Printer):
return printer.print(self.value.data)

def get_type(self) -> Attribute:
return self.type

@staticmethod
def constr(
*,
Expand Down Expand Up @@ -638,7 +643,7 @@ def __hash__(self):


@irdl_attr_definition
class FloatAttr(Generic[_FloatAttrType], ParametrizedAttribute):
class FloatAttr(Generic[_FloatAttrType], ParametrizedAttribute, WithType):
name = "float"

value: ParameterDef[FloatData]
Expand Down Expand Up @@ -672,6 +677,9 @@ def __init__(
raise ValueError(f"Invalid bitwidth: {type}")
super().__init__([data_attr, type])

def get_type(self) -> Attribute:
return self.type


AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat]
AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr)
Expand Down Expand Up @@ -1711,7 +1719,7 @@ def get_element_type(self) -> _UnrankedMemrefTypeElems:

@irdl_attr_definition
class DenseIntOrFPElementsAttr(
ParametrizedAttribute, ContainerType[IntegerType | IndexType | AnyFloat]
ParametrizedAttribute, ContainerType[IntegerType | IndexType | AnyFloat], WithType
):
name = "dense"
type: ParameterDef[
Expand Down Expand Up @@ -1864,6 +1872,9 @@ def tensor_from_list(
t = TensorType(data_type, shape)
return DenseIntOrFPElementsAttr.from_list(t, data)

def get_type(self) -> Attribute:
return self.type


Builtin = Dialect(
"builtin",
Expand Down
2 changes: 2 additions & 0 deletions xdsl/ir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,8 @@ def parse_with_type(
@abstractmethod
def print_without_type(self, printer: Printer): ...

def get_type(self) -> Attribute: ...


@dataclass(init=False)
class IRNode(ABC):
Expand Down
Loading

0 comments on commit 9fde882

Please sign in to comment.