Skip to content

Commit

Permalink
misc: use nullable getters in ConstraintContext (#3503)
Browse files Browse the repository at this point in the history
Paring this PR down until it's clear that ScopedDicts are what we want
to use for constraint contexts, reducing changes to just getters being
nullable.
  • Loading branch information
superlopuh authored Nov 22, 2024
1 parent 2a7fa60 commit 4c530f8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
5 changes: 3 additions & 2 deletions xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,9 +1214,10 @@ def matches(attr: TensorType[Attribute], other: Attribute) -> bool:
)

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if self.name in constraint_context.variables:
ctx_attr = constraint_context.get_variable(self.name)
if ctx_attr is not None:
if isa(attr, TensorType[Attribute]) and TensorIgnoreSizeConstraint.matches(
attr, constraint_context.get_variable(self.name)
attr, ctx_attr
):
return
super().verify(attr, constraint_context)
Expand Down
20 changes: 11 additions & 9 deletions xdsl/irdl/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ class ConstraintContext:
_range_variables: dict[str, tuple[Attribute, ...]] = field(default_factory=dict)
"""The assignment of constraint range variables."""

def get_variable(self, key: str) -> Attribute:
return self._variables[key]
def get_variable(self, key: str) -> Attribute | None:
return self._variables.get(key)

def get_range_variable(self, key: str) -> tuple[Attribute, ...]:
return self._range_variables[key]
def get_range_variable(self, key: str) -> tuple[Attribute, ...] | None:
return self._range_variables.get(key)

def set_variable(self, key: str, attr: Attribute):
self._variables[key] = attr
Expand Down Expand Up @@ -224,8 +224,9 @@ def verify(
attr: Attribute,
constraint_context: ConstraintContext,
) -> None:
if self.name in constraint_context.variables:
if attr != constraint_context.get_variable(self.name):
ctx_attr = constraint_context.get_variable(self.name)
if ctx_attr is not None:
if attr != ctx_attr:
raise VerifyException(
f"attribute {constraint_context.get_variable(self.name)} expected from variable "
f"'{self.name}', but got {attr}"
Expand Down Expand Up @@ -671,10 +672,11 @@ def verify(
attrs: Sequence[Attribute],
constraint_context: ConstraintContext,
) -> None:
if self.name in constraint_context.range_variables:
if tuple(attrs) != constraint_context.get_range_variable(self.name):
ctx_attrs = constraint_context.get_range_variable(self.name)
if ctx_attrs is not None:
if attrs != ctx_attrs:
raise VerifyException(
f"attributes {tuple(str(x) for x in constraint_context.get_range_variable(self.name))} expected from range variable "
f"attributes {tuple(str(x) for x in ctx_attrs)} expected from range variable "
f"'{self.name}', but got {tuple(str(x) for x in attrs)}"
)
else:
Expand Down

0 comments on commit 4c530f8

Please sign in to comment.