Skip to content

Commit

Permalink
core: Fix 'parse_single_type' for operands/results directives (#3553)
Browse files Browse the repository at this point in the history
Reworks the way 'parse_single_type' works for operands/results
directives, reusing the infrastructure introduced for
'parse_many_types'. This now allows 'parse_single_type' to work when
there are variadic operands, as demonstrated by the new tests.

The new tests manually make a `FormatProgram`, as the format program
parser will never generate such a program.

This currently makes no difference, as these functions are never called,
but it will be important for the functional-type directive (#3517), as
it allows a type like:
```mlir
(i32, i32) -> i32
```
to be parsed, even if the results of the operation is a variadic (the
functional type directive calls 'parse_single_type' when the results are
not wrapped in parentheses).
  • Loading branch information
alexarice authored Dec 2, 2024
1 parent 3d23c8f commit 52431ee
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 6 deletions.
148 changes: 148 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@
var_result_def,
var_successor_def,
)
from xdsl.irdl.declarative_assembly_format import (
AttrDictDirective,
FormatProgram,
OperandsDirective,
PunctuationDirective,
ResultsDirective,
TypeDirective,
)
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.utils.exceptions import ParseError, PyRDLOpDefinitionError, VerifyException
Expand Down Expand Up @@ -1254,6 +1262,77 @@ class ThreeOperandsOp(IRDLOperation):
parser.parse_operation()


def test_operands_directive_with_non_variadic_type_directive():
"""Tests the 'parse_single_type' function of the operands directive."""

# The parser will never generate a non-variadic TypeDirective containing
# an OperandsDirective, but we can manually make one.
format_program = FormatProgram(
(
OperandsDirective(None),
AttrDictDirective(False, set(), False),
PunctuationDirective(":"),
TypeDirective(OperandsDirective(None)),
),
{},
)

@irdl_op_definition
class OneOperandOp(IRDLOperation):
name = "test.one_operand"

op1 = operand_def()

@classmethod
def parse(cls, parser: Parser) -> OneOperandOp:
return format_program.parse(parser, cls)

def print(self, printer: Printer):
format_program.print(printer, self)

ctx = MLContext()
ctx.load_op(OneOperandOp)

check_roundtrip("test.one_operand %0 : i32", ctx)


def test_operands_directive_with_variadic_type_directive():
"""
Tests the 'parse_single_type' function of the operands directive
when the operation has a variadic.
"""
# The parser will never generate a non-variadic TypeDirective containing
# an OperandsDirective, but we can manually make one.
format_program = FormatProgram(
(
OperandsDirective((False, 1)),
AttrDictDirective(False, set(), False),
PunctuationDirective(":"),
TypeDirective(OperandsDirective((False, 1))),
),
{},
)

@irdl_op_definition
class TwoOperandOp(IRDLOperation):
name = "test.two_operand"

op1 = operand_def()
op2 = var_operand_def()

@classmethod
def parse(cls, parser: Parser) -> TwoOperandOp:
return format_program.parse(parser, cls)

def print(self, printer: Printer):
format_program.print(printer, self)

ctx = MLContext()
ctx.load_op(TwoOperandOp)

check_roundtrip("test.two_operand %0 : i32", ctx)


################################################################################
# Results #
################################################################################
Expand Down Expand Up @@ -1610,6 +1689,75 @@ class ThreeResultsOp(IRDLOperation):
parser.parse_operation()


def test_results_directive_with_non_variadic_type_directive():
"""Tests the 'parse_single_type' function of the results directive."""

# The parser will never generate a non-variadic TypeDirective containing
# a ResultsDirective, but we can manually make one.
format_program = FormatProgram(
(
AttrDictDirective(False, set(), False),
PunctuationDirective(":"),
TypeDirective(ResultsDirective(None)),
),
{},
)

@irdl_op_definition
class OneResultOp(IRDLOperation):
name = "test.one_result"

res = result_def()

@classmethod
def parse(cls, parser: Parser) -> OneResultOp:
return format_program.parse(parser, cls)

def print(self, printer: Printer):
format_program.print(printer, self)

ctx = MLContext()
ctx.load_op(OneResultOp)

check_roundtrip("%0 = test.one_result : i32", ctx)


def test_results_directive_with_variadic_type_directive():
"""
Tests the 'parse_single_type' function of the results directive
when the operation has a variadic.
"""
# The parser will never generate a non-variadic TypeDirective containing
# a ResultsDirective, but we can manually make one.
format_program = FormatProgram(
(
AttrDictDirective(False, set(), False),
PunctuationDirective(":"),
TypeDirective(ResultsDirective((False, 1))),
),
{},
)

@irdl_op_definition
class TwoResultsOp(IRDLOperation):
name = "test.two_results"

res1 = result_def()
res2 = var_result_def()

@classmethod
def parse(cls, parser: Parser) -> TwoResultsOp:
return format_program.parse(parser, cls)

def print(self, printer: Printer):
format_program.print(printer, self)

ctx = MLContext()
ctx.load_op(TwoResultsOp)

check_roundtrip("%0 = test.two_results : i32", ctx)


################################################################################
# Regions #
################################################################################
Expand Down
16 changes: 10 additions & 6 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,9 +665,11 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
return bool(operands)

def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
if len(state.operand_types) > 1:
parser.raise_error("Expected multiple types but received one.")
state.operand_types[0] = parser.parse_type()
pos_start = parser.pos
if s := self._set_using_variadic_index(
state.operand_types, "operand types", (parser.parse_type(),)
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
Expand Down Expand Up @@ -787,9 +789,11 @@ class ResultsDirective(OperandsOrResultDirective):
"""

def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
if len(state.result_types) > 1:
parser.raise_error("Expected multiple types but received one.")
state.result_types[0] = parser.parse_type()
pos_start = parser.pos
if s := self._set_using_variadic_index(
state.result_types, "result types", (parser.parse_type(),)
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
Expand Down

0 comments on commit 52431ee

Please sign in to comment.