Skip to content

Commit

Permalink
EqsatPDLRewriteFunctions inherit from PDLRewriteFunctions
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Nov 18, 2024
1 parent 63382a0 commit 5111b75
Showing 1 changed file with 3 additions and 29 deletions.
32 changes: 3 additions & 29 deletions xdsl/interpreters/eqsat_pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from xdsl.context import MLContext
from xdsl.dialects import eqsat, pdl
from xdsl.dialects.builtin import ModuleOp
from xdsl.interpreter import Interpreter, InterpreterFunctions, impl, register_impls
from xdsl.interpreters.pdl import PDLMatcher
from xdsl.interpreter import Interpreter, impl, register_impls
from xdsl.interpreters.pdl import PDLMatcher, PDLRewriteFunctions
from xdsl.ir import Attribute, Operation, OpResult, SSAValue, TypeAttribute
from xdsl.irdl import IRDLOperation
from xdsl.pattern_rewriter import PatternRewriter, RewritePattern
Expand Down Expand Up @@ -104,16 +104,14 @@ def match_and_rewrite(self, xdsl_op: Operation, rewriter: PatternRewriter) -> No

@register_impls
@dataclass
class EqsatPDLRewriteFunctions(InterpreterFunctions):
class EqsatPDLRewriteFunctions(PDLRewriteFunctions):
"""
The implementations in this class are for the RHS of the rewrite. The SSA values
referenced within the rewrite block are guaranteed to have been matched with the
corresponding IR elements. The interpreter context stores the IR elements by SSA
values.
"""

ctx: MLContext
_rewriter: PatternRewriter | None = field(default=None)
value_to_eclass: dict[SSAValue, eqsat.EClassOp] = field(default_factory=dict)
operation_to_eclass: dict[
tuple[str, tuple[tuple[str, Attribute], ...], tuple[SSAValue, ...]],
Expand All @@ -136,15 +134,6 @@ def populate_maps(self, module: ModuleOp):
] = op
self.did_populate = True

@property
def rewriter(self) -> PatternRewriter:
assert self._rewriter is not None
return self._rewriter

@rewriter.setter
def rewriter(self, rewriter: PatternRewriter):
self._rewriter = rewriter

@impl(pdl.OperationOp)
def run_operation(
self, interpreter: Interpreter, op: pdl.OperationOp, args: tuple[Any, ...]
Expand Down Expand Up @@ -216,21 +205,6 @@ def run_operation(

return (result_op,)

@impl(pdl.ResultOp)
def run_result(
self, interpreter: Interpreter, op: pdl.ResultOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
(parent,) = args
assert isinstance(parent, Operation)
return (parent.results[op.index.value.data],)

@impl(pdl.AttributeOp)
def run_attribute(
self, interpreter: Interpreter, op: pdl.AttributeOp, args: tuple[Any, ...]
) -> tuple[Any, ...]:
assert isinstance(op.value, Attribute)
return (op.value,)

@impl(pdl.ReplaceOp)
def run_replace(
self, interpreter: Interpreter, op: pdl.ReplaceOp, args: tuple[Any, ...]
Expand Down

0 comments on commit 5111b75

Please sign in to comment.