Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Add results directive #3518

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,193 @@ class OptionalResultOp(IRDLOperation):
check_equivalence(program, generic_program, ctx)


@pytest.mark.parametrize(
"program",
[
"%0 = test.results_directive : i32",
"%0, %1 = test.results_directive : i32, i32",
"%0, %1, %2 = test.results_directive : i32, i32, i32",
],
)
def test_results_directive(program: str):
"""Test the results directive"""

@irdl_op_definition
class ResultsDirectiveOp(IRDLOperation):
name = "test.results_directive"

res1 = result_def()
res2 = var_result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(ResultsDirectiveOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)


@pytest.mark.parametrize(
"program",
[
"%0 = test.results_directive : i32",
"%0, %1 = test.results_directive : i32, i32",
],
)
def test_results_directive_with_optional(program: str):
"""Test the results directive with an optional result"""

@irdl_op_definition
class ResultsDirectiveOp(IRDLOperation):
name = "test.results_directive"

res1 = opt_result_def()
res2 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(ResultsDirectiveOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)


def test_results_directive_fails_with_two_var():
"""Test results directive cannot be used with two variadic results"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'results' is ambiguous with multiple variadic results",
):

@irdl_op_definition
class TwoVarOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.two_var_op"

res1 = var_result_def()
res2 = var_result_def()

irdl_options = [AttrSizedResultSegments()]

assembly_format = "attr-dict `:` type(results)"


def test_results_directive_fails_with_no_results():
"""Test results directive cannot be used with no results"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'results' should not be used when there are no results",
):

@irdl_op_definition
class NoResultsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.no_results_op"

assembly_format = "attr-dict `:` type(results)"


def test_results_directive_fails_with_other_type_directive():
"""Test results directive cannot be used with no results"""

with pytest.raises(
PyRDLOpDefinitionError,
match="'results' cannot be used in a type directive with other result type directives",
):

@irdl_op_definition
class TwoResultsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.two_results_op"

res1 = result_def()
res2 = result_def()

assembly_format = "attr-dict `:` type($res1) `,` type(results)"


@pytest.mark.parametrize(
"program, error",
[
("%0 = test.two_results : i32", "Expected 2 result types but found 1"),
(
"%0, %1, %2 = test.two_results : i32, i32, i32",
"Expected 2 result types but found 3",
),
],
)
def test_results_directive_bounds(program: str, error: str):
@irdl_op_definition
class TwoResultsOp(IRDLOperation):
name = "test.two_results"

res1 = result_def()
res2 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(TwoResultsOp)

with pytest.raises(ParseError, match=error):
parser = Parser(ctx, program)
parser.parse_operation()


@pytest.mark.parametrize(
"program, error",
[
(
"%0 = test.three_results : i32",
"Expected at least 2 result types but found 1",
),
(
"%0, %1, %2, %3 = test.three_results : i32, i32, i32, i32",
"Expected at most 3 result types but found 4",
),
],
)
def test_results_directive_bounds_with_opt(program: str, error: str):
@irdl_op_definition
class ThreeResultsOp(IRDLOperation):
name = "test.three_results"

res1 = result_def()
res2 = opt_result_def()
res3 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(ThreeResultsOp)

with pytest.raises(ParseError, match=error):
parser = Parser(ctx, program)
parser.parse_operation()


def test_results_directive_bound_with_var():
@irdl_op_definition
class ThreeResultsOp(IRDLOperation):
name = "test.three_results"

res1 = result_def()
res2 = opt_result_def()
res3 = result_def()

assembly_format = "attr-dict `:` type(results)"

ctx = MLContext()
ctx.load_op(ThreeResultsOp)

with pytest.raises(
ParseError, match="Expected at least 2 result types but found 1"
):
parser = Parser(ctx, "%0 = test.three_results : i32")
parser.parse_operation()


################################################################################
# Regions #
################################################################################
Expand Down
51 changes: 47 additions & 4 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,11 +604,9 @@ def set_types_empty(self, state: ParsingState) -> None:


@dataclass(frozen=True)
class OperandsDirective(VariadicOperandDirective, VariadicTypeableDirective):
class OperandsOrResultDirective(VariadicTypeableDirective, ABC):
"""
An operands directive, with the following format:
operands-directive ::= operands
Prints each operand of the operation, inserting a comma between each.
Base class for the 'operands' and 'results' directives.
"""

variadic_index: tuple[bool, int] | None
Expand Down Expand Up @@ -639,6 +637,14 @@ def _set_using_variadic_index(
field[var_position] = set_to[var_position : var_position + var_length]
field[var_position + 1 :] = set_to[var_position + var_length :]


class OperandsDirective(VariadicOperandDirective, OperandsOrResultDirective):
"""
An operands directive, with the following format:
operands-directive ::= operands
Prints each operand of the operation, inserting a comma between each.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
operands = (
Expand Down Expand Up @@ -768,6 +774,43 @@ def set_types_empty(self, state: ParsingState) -> None:
state.result_types[self.index] = ()


class ResultsDirective(OperandsOrResultDirective):
"""
A results directive, with the following format:
results-directive ::= results
A typeable directive which processes the result types of the operation.
"""

def parse_single_type(self, parser: Parser, state: ParsingState) -> None:
if len(state.result_types) > 1:
parser.raise_error("Expected multiple types but received one.")
state.result_types[0] = parser.parse_type()

def parse_many_types(self, parser: Parser, state: ParsingState) -> bool:
pos_start = parser.pos
types = (
parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_type, parser.parse_type
)
or []
)

if s := self._set_using_variadic_index(
state.result_types, "result types", types
):
parser.raise_error(s, at_position=pos_start, end_position=parser.pos)
return bool(types)

def set_types_empty(self, state: ParsingState) -> None:
state.result_types = [() for _ in state.operand_types]

def get_types(self, op: IRDLOperation) -> Sequence[Attribute]:
return op.result_types

def is_present(self, op: IRDLOperation) -> bool:
return bool(op.results)


class RegionDirective(OptionallyParsableDirective, ABC):
"""
Baseclass to help keep typechecking simple.
Expand Down
22 changes: 22 additions & 0 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
PunctuationDirective,
RegionDirective,
RegionVariable,
ResultsDirective,
ResultVariable,
SuccessorVariable,
TypeableDirective,
Expand Down Expand Up @@ -702,6 +703,8 @@ def parse_typeable_directive(self) -> TypeableDirective:
"""
if self.parse_optional_keyword("operands"):
return self.create_operands_directive(False)
if self.parse_optional_keyword("results"):
return self.create_results_directive()
if variable := self.parse_optional_typeable_variable():
return variable
self.raise_error(f"unexpected token '{self._current_token.text}'")
Expand Down Expand Up @@ -773,3 +776,22 @@ def create_operands_directive(self, top_level: bool) -> OperandsDirective:
if not variadics:
return OperandsDirective(None)
return OperandsDirective(variadics[0])

def create_results_directive(self) -> ResultsDirective:
if not self.op_def.results:
self.raise_error("'results' should not be used when there are no results")
if any(self.seen_result_types):
self.raise_error(
"'results' cannot be used in a type directive with other result type directives"
)
variadics = tuple(
(isinstance(o, OptResultDef), i)
for i, (_, o) in enumerate(self.op_def.results)
if isinstance(o, VarResultDef)
)
if len(variadics) > 1:
self.raise_error("'results' is ambiguous with multiple variadic results")
self.seen_result_types = [True] * len(self.seen_result_types)
if not variadics:
return ResultsDirective(None)
return ResultsDirective(variadics[0])