Skip to content

Commit

Permalink
core: make BaseAttr inferrable (#3491)
Browse files Browse the repository at this point in the history
As discussed in #3477 

Fixes #3483
  • Loading branch information
alexarice authored Nov 20, 2024
1 parent 0c871d9 commit 746db81
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
34 changes: 33 additions & 1 deletion tests/irdl/test_attr_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from xdsl.dialects.builtin import StringAttr
from xdsl.ir import Attribute, ParametrizedAttribute
from xdsl.ir import Attribute, Data, ParametrizedAttribute
from xdsl.irdl import (
AllOf,
AnyAttr,
Expand Down Expand Up @@ -105,3 +105,35 @@ class WrapAttr(BaseWrapAttr): ...
),
)
assert not base_constr.can_infer(set())


def test_base_attr_constraint_inference():
class BaseNoParamAttr(ParametrizedAttribute):
name = "no_param"

@irdl_attr_definition
class WithParamAttr(ParametrizedAttribute):
name = "with_param"

inner: ParameterDef[Attribute]

@irdl_attr_definition
class DataAttr(Data[int]):
name = "data"

@irdl_attr_definition
class NoParamAttr(BaseNoParamAttr): ...

constr = BaseAttr(NoParamAttr)

assert constr.can_infer(set())
assert constr.infer({}) == NoParamAttr()

base_constr = BaseAttr(BaseNoParamAttr)
assert not base_constr.can_infer(set())

with_param_constr = BaseAttr(WithParamAttr)
assert not with_param_constr.can_infer(set())

data_constr = BaseAttr(DataAttr)
assert not data_constr.can_infer(set())
12 changes: 12 additions & 0 deletions xdsl/irdl/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,18 @@ def verify(
f"{attr} should be of base attribute {self.attr.name}"
)

def can_infer(self, var_constraint_names: Set[str]) -> bool:
return (
is_runtime_final(self.attr)
and issubclass(self.attr, ParametrizedAttribute)
and not self.attr.get_irdl_definition().parameters
)

def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT:
assert issubclass(self.attr, ParametrizedAttribute)
attr = self.attr.new(())
return attr

def get_unique_base(self) -> type[Attribute] | None:
if is_runtime_final(self.attr):
return self.attr
Expand Down

0 comments on commit 746db81

Please sign in to comment.