Skip to content

Commit

Permalink
Remove dialects and other changes that do not belong in the liveness …
Browse files Browse the repository at this point in the history
…analysis PR
  • Loading branch information
gabrielrodcanal committed Oct 30, 2024
1 parent f1a66a7 commit 1c98ac9
Show file tree
Hide file tree
Showing 6 changed files with 2 additions and 97 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,3 @@ docs/ret.xdsl
# direnv
.direnv
.envrc

xdsl-venv/
16 changes: 1 addition & 15 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
)
from xdsl.parser import AttrParser, Parser
from xdsl.printer import Printer
from xdsl.traits import IsContraction, IsTerminator
from xdsl.traits import IsTerminator
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa
from xdsl.utils.str_enum import StrEnum
Expand Down Expand Up @@ -834,8 +834,6 @@ class MatmulOp(NamedOpBase):

PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True

traits = frozenset([IsContraction()])

def __init__(
self,
inputs: Sequence[SSAValue],
Expand Down Expand Up @@ -999,17 +997,6 @@ class PoolingNchwMaxOp(PoolingOpsBase):
name = "linalg.pooling_nchw_max"


@irdl_op_definition
class PoolingNchwSumOp(PoolingOpsBase):
"""
Performs sum pooling
See https://mlir.llvm.org/docs/Dialects/Linalg/#linalgpooling_nchw_sum-linalgpoolingnchwsumop
"""

name = "linalg.pooling_nchw_sum"


class ConvOpsBase(IRDLOperation, ABC):
"""Base class for linalg convolution operations."""

Expand Down Expand Up @@ -1183,7 +1170,6 @@ def parse(cls, parser: Parser) -> Self:
MatmulOp,
QuantizedMatmulOp,
PoolingNchwMaxOp,
PoolingNchwSumOp,
Conv2DNchwFchwOp,
BroadcastOp,
],
Expand Down
71 changes: 0 additions & 71 deletions xdsl/dialects/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
IntegerType,
TensorType,
UnrankedTensorType,
i32,
i64,
)
from xdsl.dialects.utils import AbstractYieldOperation
from xdsl.ir import Attribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AttrSizedOperandSegments,
Expand Down Expand Up @@ -436,73 +434,6 @@ def from_static_parameters(
)


@irdl_op_definition
class Yield(AbstractYieldOperation[Attribute]):
name = "tensor.yield"


@irdl_op_definition
class PadOp(IRDLOperation):
name = "tensor.pad"
source = operand_def(TensorType)
# low = var_operand_def(IndexType)
# high = var_operand_def(IndexType)
static_low = prop_def(DenseArrayBase)
static_high = prop_def(DenseArrayBase)

result = result_def(TensorType)

irdl_options = [AttrSizedOperandSegments(as_property=True)]

def __init__(
self,
source: SSAValue | Operation,
low: Sequence[IndexType],
high: Sequence[IndexType],
):
source = SSAValue.get(source)
assert isinstance(source.type, TensorType)

new_shape = []
for dim_idx, dim in enumerate(source.type.shape.data):
new_dim = dim.data + low[dim_idx] + high[dim_idx]
new_shape.append(new_dim)

return_type = TensorType(source.type.element_type, new_shape)

super().__init__(
operands=[source],
properties={
"static_low": DenseArrayBase.from_list(i32, low),
"static_high": DenseArrayBase.from_list(i32, high),
},
result_types=[return_type],
)

@classmethod
def parse(cls, parser: Parser) -> Self:
source = parser.parse_operand()
parser.parse_keyword("low")
low = parser.parse_comma_separated_list(
Parser.Delimiter.SQUARE, parser.parse_integer
)
parser.parse_keyword("high")
high = parser.parse_comma_separated_list(
Parser.Delimiter.SQUARE, parser.parse_integer
)
parser.parse_region()
attrs = parser.parse_optional_attr_dict()
parser.parse_punctuation(":")
parser.parse_type()
parser.parse_keyword("to")
parser.parse_type()

pad = cls(source, low, high)
pad.attributes |= attrs

return pad


Tensor = Dialect(
"tensor",
[
Expand All @@ -513,8 +444,6 @@ def parse(cls, parser: Parser) -> Self:
InsertSliceOp,
ReshapeOp,
CollapseShapeOp,
PadOp,
Yield,
],
[],
)
1 change: 0 additions & 1 deletion xdsl/ir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,6 @@ def emit_error(

diagnostic = Diagnostic()
diagnostic.add_message(self, message)
print("OPERATION ERROR: ", self)
diagnostic.raise_exception(message, self, exception_type, underlying_error)

@classmethod
Expand Down
3 changes: 1 addition & 2 deletions xdsl/irdl/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,7 @@ def verify(
# constraint context is not modified.
constraint_context_copy = constraint_context.copy()
try:
# GABRIEL 22/10/2024: this verification fails with matmul in ResNet (HIDA test)
# attr_constr.verify(attr, constraint_context_copy)
attr_constr.verify(attr, constraint_context_copy)
# If the constraint succeeds, we update back the constraint variables
constraint_context.update(constraint_context_copy)
return
Expand Down
6 changes: 0 additions & 6 deletions xdsl/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ def verify(self, op: Operation) -> None:
OpTraitInvT = TypeVar("OpTraitInvT", bound=OpTrait)


class IsContraction(OpTrait):
"""
Temporary patch to qualify operations as contractions. This is done with an interface for the linalg dialect.
"""


class ConstantLike(OpTrait):
"""
Operation known to be constant-like.
Expand Down

0 comments on commit 1c98ac9

Please sign in to comment.