Skip to content

Commit

Permalink
core: Add results directive
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Nov 27, 2024
1 parent 9bc6dc1 commit e46d8ac
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 4 deletions.
187 changes: 187 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,193 @@ class OptionalResultOp(IRDLOperation):
check_equivalence(program, generic_program, ctx)


@pytest.mark.parametrize(
"program",
[
"%0 = test.results_directive : i32",
"%0, %1 = test.results_directive : i32, i32",
"%0, %1, %2 = test.results_directive : i32, i32, i32",
],
)
def test_results_directive(program: str):
"""Test the results directive"""

@irdl_op_definition
class ResultsDirectiveOp(IRDLOperation):
name = "test.results_directive"

res1 = result_def()
res2 = var_result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(ResultsDirectiveOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)


@pytest.mark.parametrize(
"program",
[
"%0 = test.results_directive : i32",
"%0, %1 = test.results_directive : i32, i32",
],
)
def test_results_directive_with_optional(program: str):
"""Test the results directive with an optional result"""

@irdl_op_definition
class ResultsDirectiveOp(IRDLOperation):
name = "test.results_directive"

res1 = opt_result_def()
res2 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(ResultsDirectiveOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)


def test_results_directive_fails_with_two_var():
"""Test results directive cannot be used with two variadic results"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'results' is ambiguous with multiple variadic results",
):

@irdl_op_definition
class TwoVarOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.two_var_op"

res1 = var_result_def()
res2 = var_result_def()

irdl_options = [AttrSizedResultSegments()]

assembly_format = "attr-dict `:` type(results)"


def test_results_directive_fails_with_no_results():
"""Test results directive cannot be used with no results"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'results' should not be used when there are no results",
):

@irdl_op_definition
class NoResultsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.no_results_op"

assembly_format = "attr-dict `:` type(results)"


def test_results_directive_fails_with_other_type_directive():
"""Test results directive cannot be used with no results"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'results' cannot be used in a type directive with other result type directives",
):

@irdl_op_definition
class TwoResultsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.two_results_op"

res1 = result_def()
res2 = result_def()

assembly_format = "attr-dict `:` type($res1) `,` type(results)"


@pytest.mark.parametrize(
"program, error",
[
("%0 = test.two_results : i32", "Expected 2 result types but found 1"),
(
"%0, %1, %2 = test.two_results : i32, i32, i32",
"Expected 2 result types but found 3",
),
],
)
def test_results_directive_bounds(program: str, error: str):
@irdl_op_definition
class TwoResultsOp(IRDLOperation):
name = "test.two_results"

res1 = result_def()
res2 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(TwoResultsOp)

with pytest.raises(ParseError, match=error):
parser = Parser(ctx, program)
parser.parse_operation()


@pytest.mark.parametrize(
"program, error",
[
(
"%0 = test.three_results : i32",
"Expected at least 2 result types but found 1",
),
(
"%0, %1, %2, %3 = test.three_results : i32, i32, i32, i32",
"Expected at most 3 result types but found 4",
),
],
)
def test_results_directive_bounds_with_opt(program: str, error: str):
@irdl_op_definition
class ThreeResultsOp(IRDLOperation):
name = "test.three_results"

res1 = result_def()
res2 = opt_result_def()
res3 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(ThreeResultsOp)

with pytest.raises(ParseError, match=error):
parser = Parser(ctx, program)
parser.parse_operation()


def test_results_directive_bound_with_var():
@irdl_op_definition
class ThreeResultsOp(IRDLOperation):
name = "test.three_results"

res1 = result_def()
res2 = opt_result_def()
res3 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(ThreeResultsOp)

with pytest.raises(
ParseError, match="Expected at least 2 result types but found 1"
):
parser = Parser(ctx, "%0 = test.three_results : i32")
parser.parse_operation()


################################################################################
# Regions #
################################################################################
Expand Down
51 changes: 47 additions & 4 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,11 +604,9 @@ def set_types_empty(self, state: ParsingState) -> None:


@dataclass(frozen=True)
class OperandsDirective(VariadicOperandDirective, VariadicTypeableDirective):
class OperandsOrResultDirective(VariadicTypeableDirective, ABC):
"""
An operands directive, with the following format:
operands-directive ::= operands
Prints each operand of the operation, inserting a comma between each.
Base class for the 'operands' and 'results' directives.
"""

variadic_index: tuple[bool, int] | None
Expand Down Expand Up @@ -639,6 +637,14 @@ def _set_using_variadic_index(
field[var_position] = set_to[var_position : var_position + var_length]
field[var_position + 1 :] = set_to[var_position + var_length :]


class OperandsDirective(VariadicOperandDirective, OperandsOrResultDirective):
"""
An operands directive, with the following format:
operands-directive ::= operands
Prints each operand of the operation, inserting a comma between each.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
operands = (
Expand Down Expand Up @@ -768,6 +774,43 @@ def set_types_empty(self, state: ParsingState) -> None:
state.result_types[self.index] = ()


class ResultsDirective(OperandsOrResultDirective):
"""
A results directive, with the following format:
results-directive ::= results
A typeable directive which processes the result types of the operation.
"""

def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
if len(state.result_types) > 1:
parser.raise_error("Expected multiple types but received one.")
state.result_types[0] = parser.parse_type()

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
types = (
parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_type, parser.parse_type
)
or []
)

if s := self._set_using_variadic_index(
state.result_types, "result types", types
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)
return bool(types)

def set_types_empty(self, state: ParsingState) -> None:
state.result_types = [() for _ in state.operand_types]

def get_types(self, op: IRDLOperation) -> Sequence[Attribute]:
return op.result_types

def is_present(self, op: IRDLOperation) -> bool:
return bool(op.results)


class RegionDirective(OptionallyParsableDirective, ABC):
"""
Baseclass to help keep typechecking simple.
Expand Down
22 changes: 22 additions & 0 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
PunctuationDirective,
RegionDirective,
RegionVariable,
ResultsDirective,
ResultVariable,
SuccessorVariable,
TypeableDirective,
Expand Down Expand Up @@ -702,6 +703,8 @@ def parse_typeable_directive(self) -> TypeableDirective:
"""
if self.parse_optional_keyword("operands"):
return self.create_operands_directive(False)
if self.parse_optional_keyword("results"):
return self.create_results_directive()
if variable := self.parse_optional_typeable_variable():
return variable
self.raise_error(f"unexpected token '{self._current_token.text}'")
Expand Down Expand Up @@ -773,3 +776,22 @@ def create_operands_directive(self, top_level: bool) -> OperandsDirective:
if not variadics:
return OperandsDirective(None)
return OperandsDirective(variadics[0])

def create_results_directive(self) -> ResultsDirective:
if not self.op_def.results:
self.raise_error("'results' should not be used when there are no results")
if any(self.seen_result_types):
self.raise_error(
"'results' cannot be used in a type directive with other result type directives"
)
variadics = tuple(
(isinstance(o, OptResultDef), i)
for i, (_, o) in enumerate(self.op_def.results)
if isinstance(o, VarResultDef)
)
if len(variadics) > 1:
self.raise_error("'results' is ambiguous with multiple variadic results")
self.seen_result_types = [True] * len(self.seen_result_types)
if not variadics:
return ResultsDirective(None)
return ResultsDirective(variadics[0])

0 comments on commit e46d8ac

Please sign in to comment.