Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Nov 26, 2024
1 parent a15921a commit c924b31
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 12 deletions.
162 changes: 161 additions & 1 deletion tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,7 +1363,7 @@ class OptionalResultOp(IRDLOperation):
],
)
def test_results_directive(program: str):
"""Test the operands directive"""
"""Test the results directive"""

@irdl_op_definition
class ResultsDirectiveOp(IRDLOperation):
Expand All @@ -1381,6 +1381,166 @@ class ResultsDirectiveOp(IRDLOperation):
check_roundtrip(program, ctx)


@pytest.mark.parametrize(
"program",
[
"%0 = test.results_directive : i32",
"%0, %1 = test.results_directive : i32, i32",
],
)
def test_results_directive_with_optional(program: str):
"""Test the results directive with an optional result"""

@irdl_op_definition
class ResultsDirectiveOp(IRDLOperation):
name = "test.results_directive"

res1 = opt_result_def()
res2 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(ResultsDirectiveOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)


def test_results_directive_fails_with_two_var():
"""Test results directive cannot be used with two variadic results"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'results' is ambiguous with multiple variadic results",
):

@irdl_op_definition
class TwoVarOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.two_var_op"

res1 = var_result_def()
res2 = var_result_def()

irdl_options = [AttrSizedResultSegments()]

assembly_format = "attr-dict `:` type(results)"


def test_results_directive_fails_with_no_results():
"""Test results directive cannot be used with no results"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'results' should not be used when there are no results",
):

@irdl_op_definition
class NoResultsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.no_results_op"

assembly_format = "attr-dict `:` type(results)"


def test_results_directive_fails_with_other_type_directive():
"""Test results directive cannot be used with no results"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'results' cannot be used in a type directive with other result type directives",
):

@irdl_op_definition
class TwoResultsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.two_results_op"

res1 = result_def()
res2 = result_def()

assembly_format = "attr-dict `:` type($res1) `,` type(results)"


@pytest.mark.parametrize(
"program, error",
[
("%0 = test.two_results : i32", "Expected 2 result types but found 1"),
(
"%0, %1, %2 = test.two_results : i32, i32, i32",
"Expected 2 result types but found 3",
),
],
)
def test_results_directive_bounds(program: str, error: str):
@irdl_op_definition
class TwoResultsOp(IRDLOperation):
name = "test.two_results"

res1 = result_def()
res2 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(TwoResultsOp)

with pytest.raises(ParseError, match=error):
parser = Parser(ctx, program)
parser.parse_operation()


@pytest.mark.parametrize(
"program, error",
[
(
"%0 = test.three_results : i32",
"Expected at least 2 result types but found 1",
),
(
"%0, %1, %2, %3 = test.three_results : i32, i32, i32, i32",
"Expected at most 3 result types but found 4",
),
],
)
def test_results_directive_bounds_with_opt(program: str, error: str):
@irdl_op_definition
class ThreeResultsOp(IRDLOperation):
name = "test.three_results"

res1 = result_def()
res2 = opt_result_def()
res3 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(ThreeResultsOp)

with pytest.raises(ParseError, match=error):
parser = Parser(ctx, program)
parser.parse_operation()


def test_results_directive_bound_with_var():
@irdl_op_definition
class ThreeResultsOp(IRDLOperation):
name = "test.three_results"

res1 = result_def()
res2 = opt_result_def()
res3 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(ThreeResultsOp)

with pytest.raises(
ParseError, match="Expected at least 2 result types but found 1"
):
parser = Parser(ctx, "%0 = test.three_results : i32")
parser.parse_operation()


################################################################################
# Regions #
################################################################################
Expand Down
23 changes: 14 additions & 9 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,22 +612,23 @@ class OperandsOrResultDirective(VariadicTypeableDirective, ABC):
variadic_index: tuple[bool, int] | None

def _set_using_variadic_index(
self, field: list[_T | None | Sequence[_T]], set_to: Sequence[_T]
self,
field: list[_T | None | Sequence[_T]],
field_str: str,
set_to: Sequence[_T],
) -> str | None:
if self.variadic_index is None:
if len(set_to) != len(field):
return f"Expected {len(field)} operands but found {len(set_to)}"
return f"Expected {len(field)} {field_str} but found {len(set_to)}"
field = [o for o in set_to] # Copy needed as list is not covariant
return

is_optional, var_position = self.variadic_index
var_length = len(set_to) - len(field) + 1
if var_length < 0:
return (
f"Expected at least {len(field) - 1} operands but found {len(set_to)}"
)
return f"Expected at least {len(field) - 1} {field_str} but found {len(set_to)}"
if var_length > 1 and is_optional:
return f"Expected at most {len(field)} operands but found {len(set_to)}"
return f"Expected at most {len(field)} {field_str} but found {len(set_to)}"
field[:var_position] = set_to[:var_position]
field[var_position] = set_to[var_position : var_position + var_length]
field[var_position + 1 :] = set_to[var_position + var_length :]
Expand All @@ -652,7 +653,7 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
or []
)

if s := self._set_using_variadic_index(state.operands, operands):
if s := self._set_using_variadic_index(state.operands, "operands", operands):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)
return bool(operands)

Expand All @@ -670,7 +671,9 @@ def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
or []
)

if s := self._set_using_variadic_index(state.operand_types, types):
if s := self._set_using_variadic_index(
state.operand_types, "operand types", types
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)
return bool(types)

Expand Down Expand Up @@ -792,7 +795,9 @@ def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
or []
)

if s := self._set_using_variadic_index(state.result_types, types):
if s := self._set_using_variadic_index(
state.result_types, "result types", types
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)
return bool(types)

Expand Down
4 changes: 2 additions & 2 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,10 +787,10 @@ def create_operands_directive(self, top_level: bool) -> OperandsDirective:

def create_results_directive(self) -> ResultsDirective:
if not self.op_def.results:
self.raise_error("'results' should not be used when there are no operands")
self.raise_error("'results' should not be used when there are no results")
if any(self.seen_result_types):
self.raise_error(
"'results' can not be used in a type directive with other result type directives"
"'results' cannot be used in a type directive with other result type directives"
)
variadics = tuple(
(isinstance(o, OptResultDef), i)
Expand Down

0 comments on commit c924b31

Please sign in to comment.