Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: introduce param_def API for ParametrizedAttribute declaration #3444

Draft
wants to merge 63 commits into
base: sasha/misc/type-var-constr-merge
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 62 commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
8e3df95
add TypeVarConstraint to store type information
superlopuh Sep 4, 2024
76de217
wip wip iwp
superlopuh Sep 11, 2024
adcc2f8
WIP it works need to split up
superlopuh Sep 17, 2024
12136c6
Merge branch 'main' into sasha/misc/type-var-constr
superlopuh Sep 17, 2024
6d65312
misc: remove redundant `AnyAttr()` constraints from op definitions
superlopuh Sep 17, 2024
adb3882
Merge branch 'sasha/misc/no-any-attr-in-def' into sasha/misc/type-var…
superlopuh Sep 17, 2024
21d2f80
misc: raise instead of returning error in MemrefLayoutAttr.get_affine…
superlopuh Sep 17, 2024
697da11
Merge branch 'sasha/misc/raise-not-implemented' into sasha/misc/type-…
superlopuh Sep 17, 2024
1ef3492
wip works
superlopuh Sep 17, 2024
6dfd769
api: (irdl) ParamAttrConstraint is a BaseAttr constraint, not a type
superlopuh Sep 17, 2024
c508e80
Merge branch 'sasha/irdl/param-attr-constraint-base' into sasha/misc/…
superlopuh Sep 17, 2024
3ee0a70
use BaseAttr for StencilTypeConstr
superlopuh Sep 17, 2024
7629bcb
dialects: (stencil) add StencilTypeConstr
superlopuh Sep 17, 2024
3dce048
Merge branch 'sasha/stencil/stencil-type-constr' into sasha/misc/type…
superlopuh Sep 17, 2024
7c10bca
dialects: (stencil) add StencilTypeConstr
superlopuh Sep 17, 2024
fa2427b
Merge branch 'sasha/stencil/stencil-type-constr' into sasha/misc/type…
superlopuh Sep 17, 2024
9c47a82
use more AnyMemRefTypeConstr and AnyTensorTypeConstr
superlopuh Sep 18, 2024
77faf83
dialects: (builtin) add AnyMemRefTypeConstr and AnyTensorTypeConstr
superlopuh Sep 18, 2024
a8dfc11
Merge branch 'sasha/misc/any-memref-tensor-type-constr' into sasha/mi…
superlopuh Sep 18, 2024
a5a09f8
dialects: (stencil) add StencilTypeConstr
superlopuh Sep 17, 2024
f1e2898
Merge branch 'sasha/stencil/stencil-type-constr' into sasha/misc/type…
superlopuh Sep 18, 2024
c46787f
dialects: (stencil) add StencilTypeConstr
superlopuh Sep 17, 2024
9c61860
Merge branch 'sasha/stencil/stencil-type-constr' into sasha/misc/type…
superlopuh Sep 18, 2024
bb94566
no need for isinstance of TypeVarConstraint in irdl_to_attr_constraint
superlopuh Sep 18, 2024
6bec340
use new stencil type constraints in param attr constraints
superlopuh Sep 18, 2024
5959a8d
dialects: (builtin) add AnyMemRefTypeConstr and AnyTensorTypeConstr
superlopuh Sep 18, 2024
51b6d07
dialects: (stencil) add StencilTypeConstr
superlopuh Sep 17, 2024
c31fa27
api: (irdl) ParamAttrConstraint is a BaseAttr constraint, not a type
superlopuh Sep 17, 2024
9f4b2f7
Merge branch 'sasha/irdl/param-attr-constraint-base' into sasha/misc/…
superlopuh Sep 18, 2024
8d33d04
revert unnecesary changes
superlopuh Sep 18, 2024
2893129
accept abstract base Attribute classes also
superlopuh Sep 18, 2024
2a4f14e
dialects: (builtin) add AnyMemRefTypeConstr and AnyTensorTypeConstr
superlopuh Sep 18, 2024
ee9980b
dialects: (stencil) add StencilTypeConstr
superlopuh Sep 17, 2024
3d3bf4c
api: (irdl) ParamAttrConstraint is a BaseAttr constraint, not a type
superlopuh Sep 17, 2024
38dd699
Merge branch 'sasha/irdl/param-attr-constraint-base' into sasha/misc/…
superlopuh Oct 4, 2024
c6b190b
api: (irdl) ParamAttrConstraint base_attr -> base_constr
superlopuh Oct 4, 2024
d9e5859
Merge branch 'sasha/irdl/param-attr-constraint-base' into sasha/misc/…
superlopuh Oct 4, 2024
fd153eb
format
superlopuh Oct 6, 2024
6a32589
dialects: (stencil) add StencilTypeConstr
superlopuh Sep 17, 2024
90d1c93
api: (irdl) ParamAttrConstraint base_attr -> base_constr
superlopuh Oct 4, 2024
4c79972
Merge branch 'sasha/irdl/param-attr-constraint-base' into sasha/misc/…
superlopuh Oct 6, 2024
bf81d8c
Merge branch 'main' into sasha/misc/type-var-constr
superlopuh Nov 13, 2024
e5eb94c
revert reorder
superlopuh Nov 13, 2024
56102f5
wip wip wip
superlopuh Nov 13, 2024
691c1a6
core: add attr_constr_rewrite_pattern
superlopuh Nov 13, 2024
7a8b187
Merge branch 'sasha/pyrdl/attr-constr-rewrite-pattern' into sasha/mis…
superlopuh Nov 13, 2024
8c6d4bc
use constr pattern in stencil type conversion
superlopuh Nov 13, 2024
52a5e43
dialects: (stream) add constr to StreamType base class
superlopuh Nov 13, 2024
2dd1b0f
Merge branch 'main' into sasha/misc/type-var-constr
superlopuh Nov 13, 2024
616a67b
Merge branch 'sasha/stream/constr' into sasha/misc/type-var-constr
superlopuh Nov 13, 2024
b203e5a
Merge branch 'main' into sasha/stream/constr
superlopuh Nov 13, 2024
6b8686e
Merge branch 'sasha/stream/constr' into sasha/misc/type-var-constr
superlopuh Nov 13, 2024
0937de7
use constr correctly
superlopuh Nov 13, 2024
1b51b4d
make stream constr static again
superlopuh Nov 13, 2024
9e4aad1
Merge branch 'sasha/stream/constr' into sasha/misc/type-var-constr
superlopuh Nov 13, 2024
1387b91
remove unnecessary type annotation
superlopuh Nov 13, 2024
41ec646
Merge branch 'main' into sasha/misc/type-var-constr
superlopuh Nov 13, 2024
ead530c
wip need some more tests and to check that the constraint creation works
superlopuh Jul 25, 2024
df9395b
wip test works
superlopuh Nov 13, 2024
7224141
Merge branch 'sasha/misc/type-var-constr-merge' into sasha/irdl/typed…
superlopuh Nov 13, 2024
7167cf4
remove some casts
superlopuh Nov 13, 2024
a904f29
add a test for mixed use
superlopuh Nov 13, 2024
827477f
Merge branch 'sasha/misc/type-var-constr-merge' into sasha/irdl/typed…
superlopuh Nov 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions tests/irdl/test_attribute_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
base,
irdl_attr_definition,
irdl_to_attr_constraint,
param_def,
)
from xdsl.parser import AttrParser
from xdsl.printer import Printer
Expand Down Expand Up @@ -816,6 +817,26 @@ def test_irdl_definition():
)


@irdl_attr_definition
class ParamAttrDefAttr2(ParametrizedAttribute):
name = "test.param_attr_def_attr"

arg1 = param_def(Attribute)
arg2 = param_def(BoolData)

# Check that we can define methods in attribute definition
def test(self):
pass


def test_irdl_definition2():
"""Test that we can get the IRDL definition of a parametrized attribute."""

assert ParamAttrDefAttr2.get_irdl_definition() == ParamAttrDef(
"test.param_attr_def_attr", [("arg1", AnyAttr()), ("arg2", BaseAttr(BoolData))]
)


class InvalidTypedFieldTestAttr(ParametrizedAttribute):
name = "test.invalid_typed_field"

Expand Down Expand Up @@ -899,6 +920,50 @@ def test_generic_attr():
)


@irdl_attr_definition
class GenericAttr2(Generic[AttributeInvT], ParametrizedAttribute):
name = "test.generic_attr"

param = param_def(AttributeInvT)


def test_generic_attr2():
"""Test the generic parameter of a ParametrizedAttribute."""

assert GenericAttr2.get_irdl_definition() == ParamAttrDef(
"test.generic_attr",
[
(
"param",
TypeVarConstraint(
type_var=AttributeInvT,
constraint=AnyAttr(),
),
)
],
)

assert base(GenericAttr2[IntAttr]) == ParamAttrConstraint(
GenericAttr2, (BaseAttr(IntAttr),)
)


def test_mixed_param_def_apis():
"""Test that using both ParamDef and param_def raises an error."""
with pytest.raises(
ValueError,
match="ParametrizedAttribute definitions must not mix `param_def` and ParamDef "
"declarations.",
):

@irdl_attr_definition
class InvalidAttr(ParametrizedAttribute): # pyright: ignore[reportUnusedClass]
name = "test.invalid"
# Using both styles is invalid
param1: ParameterDef[IntegerType] # Using annotation style
param2 = param_def(IntegerType) # Using param_def style


################################################################################
# ConstraintVar
################################################################################
Expand Down
5 changes: 4 additions & 1 deletion xdsl/dialects/riscv_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue
from xdsl.irdl import (
BaseAttr,
VarConstraint,
attr_def,
base,
Expand Down Expand Up @@ -457,7 +458,9 @@ def assembly_instruction_name(self) -> str:
class GetStreamOp(RISCVAsmOperation):
name = "riscv_snitch.get_stream"

stream = result_def(stream.StreamType[riscv.FloatRegisterType])
stream = result_def(
stream.StreamType.constr(element_type=BaseAttr(riscv.FloatRegisterType))
)

def __init__(self, result_type: Attribute):
super().__init__(result_types=[result_type])
Expand Down
98 changes: 80 additions & 18 deletions xdsl/irdl/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
Annotated,
Any,
Generic,
Literal,
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
overload,
)

from xdsl.ir import (
Expand Down Expand Up @@ -81,12 +83,56 @@ def generic_constraint_coercion(args: tuple[Any]) -> AttrConstraint:

_A = TypeVar("_A", bound=Attribute)


# Will deprecate soon
ParameterDef = Annotated[_A, IRDLAnnotations.ParamDefAnnot]


def irdl_param_attr_get_param_type_hints(cls: type[_A]) -> list[tuple[str, Any]]:
# Field definition classes for `@irdl_param_attr_definition`
class _ParameterDef:
param: AttrConstraint | Attribute | type[Attribute] | TypeVar | ConstraintVar

def __init__(
self,
param: AttrConstraint | Attribute | type[Attribute] | TypeVar | ConstraintVar,
):
self.param = param


@overload
def param_def(
constraint: TypeVar,
*,
default: None = None,
resolver: None = None,
init: Literal[False] = False,
) -> Attribute: ...
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't feel ideal? We effectively lose the python typing for the field which was kind of the whole point of having the typevar

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep and it's currently breaking my code when I migrate xDSL to it. I'll probably change this soon.



@overload
def param_def(
constraint: type[AttributeInvT] | GenericAttrConstraint[AttributeInvT],
*,
default: None = None,
resolver: None = None,
init: Literal[False] = False,
) -> AttributeInvT: ...


def param_def(
constraint: type[AttributeInvT] | TypeVar | GenericAttrConstraint[AttributeInvT],
*,
default: None = None,
resolver: None = None,
init: Literal[False] = False,
) -> AttributeInvT:
"""Defines a property of an operation."""
return cast(AttributeInvT, _ParameterDef(constraint))


def irdl_param_attr_get_param_type_hints(cls: type[_A]) -> dict[str, Any]:
"""Get the type hints of an IRDL parameter definitions."""
res = list[tuple[str, Any]]()
res: dict[str, Any] = {}
for field_name, field_type in get_type_hints(cls, include_extras=True).items():
if field_name == "name" or field_name == "parameters":
continue
Expand All @@ -100,7 +146,7 @@ def irdl_param_attr_get_param_type_hints(cls: type[_A]) -> list[tuple[str, Any]]
+ f"type `ParameterDef[<Constraint>]`, got type {field_type}."
)

res.append((field_name, field_type))
res[field_name] = field_type
return res


Expand Down Expand Up @@ -133,16 +179,28 @@ def from_pyrdl(
for key, value in parent_cls.__dict__.items()
if key not in _PARAMETRIZED_ATTRIBUTE_DICT_KEYS
}
# Fields of the class with type annotations
param_hints = irdl_param_attr_get_param_type_hints(pyrdl_def)

# The resulting parameters
parameters = list[tuple[str, AttrConstraint]]()

# Check that all fields of the attribute definition are either already
# in ParametrizedAttribute, or are class functions or methods.
for field_name, value in clsdict.items():
if field_name == "name":
# Ignore name field
continue
if isinstance(
value, FunctionType | PropertyType | classmethod | staticmethod
):
# Ignore functions
continue

# Parameter def must be a field def
if isinstance(value, _ParameterDef):
constraint = irdl_to_attr_constraint(value.param, allow_type_var=True)
parameters.append((field_name, constraint))
continue

# Constraint variables are allowed
if get_origin(value) is Annotated:
if any(isinstance(arg, ConstraintVar) for arg in get_args(value)):
Expand All @@ -160,17 +218,13 @@ def from_pyrdl(

name = clsdict["name"]

param_hints = irdl_param_attr_get_param_type_hints(pyrdl_def)
if issubclass(pyrdl_def, TypedAttribute):
pyrdl_def = cast(type[TypedAttribute[Attribute]], pyrdl_def)
try:
param_names = [name for name, _ in param_hints]
type_index = param_names.index("type")
except ValueError:
if "type" not in param_hints:
raise PyRDLAttrDefinitionError(
f"TypedAttribute {pyrdl_def.__name__} should have a 'type' parameter."
)
typed_hint = param_hints[type_index][1]
typed_hint = param_hints["type"]
if get_origin(typed_hint) is Annotated:
typed_hint = get_args(typed_hint)[0]
type_var = get_type_var_mapping(pyrdl_def)[1][AttributeCovT]
Expand All @@ -181,8 +235,13 @@ def from_pyrdl(
" as the type variable in the TypedAttribute base class."
)

parameters = list[tuple[str, AttrConstraint]]()
for param_name, param_type in param_hints:
if parameters and param_hints:
raise ValueError(
"ParametrizedAttribute definitions must not mix `param_def` and "
"ParamDef declarations."
)

for param_name, param_type in param_hints.items():
constraint = irdl_to_attr_constraint(param_type, allow_type_var=True)
parameters.append((param_name, constraint))

Expand Down Expand Up @@ -310,7 +369,8 @@ def irdl_to_attr_constraint(
allow_type_var: bool = False,
) -> AttrConstraint:
if isinstance(irdl, GenericAttrConstraint):
return cast(AttrConstraint, irdl)
constr: GenericAttrConstraint[Attribute] = irdl
return constr

if isinstance(irdl, Attribute):
return EqAttrConstraint(irdl)
Expand Down Expand Up @@ -357,9 +417,8 @@ def irdl_to_attr_constraint(
args = get_args(irdl)
if len(args) != 1:
raise Exception(f"GenericData args must have length 1, got {args}")
origin = cast(type[GenericData[Any]], origin)
args = cast(tuple[Attribute], args)
return AllOf((BaseAttr(origin), origin.generic_constraint_coercion(args)))
cls: type[GenericData[Attribute]] = origin
return AllOf((BaseAttr(cls), origin.generic_constraint_coercion(args)))

# Generic ParametrizedAttributes case
# We translate it to constraints over the attribute parameters.
Expand All @@ -383,7 +442,10 @@ def irdl_to_attr_constraint(

type_var_mapping = dict(zip(generic_args, args))

origin_parameters = irdl_param_attr_get_param_type_hints(origin)
# Map the constraints in the attribute definition
attr_def = origin.get_irdl_definition()
origin_parameters = attr_def.parameters

origin_constraints = [
irdl_to_attr_constraint(param, allow_type_var=True).mapping_type_vars(
type_var_mapping
Expand Down
3 changes: 2 additions & 1 deletion xdsl/transforms/experimental/convert_stencil_to_ll_mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
PatternRewriteWalker,
RewritePattern,
TypeConversionPattern,
attr_constr_rewrite_pattern,
attr_type_rewrite_pattern,
op_type_rewrite_pattern,
)
Expand Down Expand Up @@ -649,7 +650,7 @@ def return_target_analysis(module: builtin.ModuleOp):


class StencilTypeConversion(TypeConversionPattern):
@attr_type_rewrite_pattern
@attr_constr_rewrite_pattern(StencilTypeConstr)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a different change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's more related to the previous change, let me revert this and see what happens

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope it's related to this change, since we use the attr_def instead of the type hints to construct the constraint, we can't do it on the ABC that is StencilTypeConstr since it's not really an attribute. We have to construct the constraint manually instead.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly in the change in riscv_snitch

def convert_type(self, typ: StencilType[Attribute]) -> MemRefType[Attribute]:
return StencilToMemRefType(typ)

Expand Down
Loading