Skip to content

Commit

Permalink
Add extra test and fix operands/results directives
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Nov 27, 2024
1 parent 940bc1a commit 83b2fa4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
29 changes: 29 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,6 +1601,35 @@ class FunctionalTypeOp(IRDLOperation):
ctx.load_op(FunctionalTypeOp)


@pytest.mark.parametrize(
"program",
[
"%0 = test.functional_type %1, %2 : (i32, i32) -> i32",
"%0, %1 = test.functional_type %2, %3 : (i32, i32) -> (i32, i32)",
"%0 = test.functional_type %1 : (i32) -> i32",
],
)
def test_functional_type_with_operands_and_results(program: str):
"""
Test the parsing of the functional-type directive using the operands and
results directives
"""

@irdl_op_definition
class FunctionalTypeOp(IRDLOperation):
name = "test.functional_type"

op1 = operand_def()
ops2 = var_operand_def()
res1 = var_result_def()
res2 = result_def()

assembly_format = "operands attr-dict `:` functional-type(operands, results)"

ctx = MLContext()
ctx.load_op(FunctionalTypeOp)


################################################################################
# Regions #
################################################################################
Expand Down
16 changes: 10 additions & 6 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,9 +660,11 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
return bool(operands)

def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
if len(state.operand_types) > 1:
parser.raise_error("Expected multiple types but received one.")
state.operand_types[0] = parser.parse_type()
pos_start = parser.pos
if s := self._set_using_variadic_index(
state.operand_types, "operand types", (parser.parse_type(),)
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
Expand Down Expand Up @@ -782,9 +784,11 @@ class ResultsDirective(OperandsOrResultDirective):
"""

def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
if len(state.result_types) > 1:
parser.raise_error("Expected multiple types but received one.")
state.result_types[0] = parser.parse_type()
pos_start = parser.pos
if s := self._set_using_variadic_index(
state.result_types, "result types", (parser.parse_type(),)
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
Expand Down

0 comments on commit 83b2fa4

Please sign in to comment.