Skip to content

Commit

Permalink
Fix bufferization
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Nov 19, 2024
1 parent bdc8420 commit aeadd84
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
5 changes: 2 additions & 3 deletions tests/dialects/test_bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from xdsl.dialects.test import TestOp
from xdsl.ir import Attribute
from xdsl.irdl import (
ConstraintContext,
EqAttrConstraint,
IRDLOperation,
VarConstraint,
Expand All @@ -40,13 +39,13 @@ def test_tensor_from_memref_inference():
EqAttrConstraint(MemRefType(f64, [10, 20, 30]))
)
assert constr2.can_infer(set())
assert constr2.infer(ConstraintContext()) == TensorType(f64, [10, 20, 30])
assert constr2.infer(dict()) == TensorType(f64, [10, 20, 30])

constr3 = TensorFromMemrefConstraint(
EqAttrConstraint(UnrankedMemrefType.from_type(f64))
)
assert constr3.can_infer(set())
assert constr3.infer(ConstraintContext()) == UnrankedTensorType(f64)
assert constr3.infer(dict()) == UnrankedTensorType(f64)


@irdl_op_definition
Expand Down
10 changes: 6 additions & 4 deletions xdsl/dialects/bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@ def infer(
def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if isa(attr, TensorType[Attribute]):
memref_type = MemRefType(attr.element_type, attr.shape)
return self.memref_constraint.verify(memref_type, constraint_context)
if isa(attr, UnrankedTensorType[Attribute]):
elif isa(attr, UnrankedTensorType[Attribute]):
memref_type = UnrankedMemrefType.from_type(attr.element_type)

raise VerifyException(f"Expected TensorType or UnrankedTensorType, got {attr}")
else:
raise VerifyException(
f"Expected tensor or unranked tensor type, got {attr}"
)
return self.memref_constraint.verify(memref_type, constraint_context)


@irdl_op_definition
Expand Down

0 comments on commit aeadd84

Please sign in to comment.