Skip to content

Commit

Permalink
dialects: (stream) use assembly format for stream read and write ops (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh authored Nov 20, 2024
1 parent 680b556 commit f7088a0
Showing 1 changed file with 4 additions and 40 deletions.
44 changes: 4 additions & 40 deletions xdsl/dialects/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import abc
from typing import ClassVar, Generic, TypeVar, cast

from typing_extensions import Self

from xdsl.dialects.builtin import ContainerType
from xdsl.ir import (
Attribute,
Expand All @@ -25,8 +23,6 @@
operand_def,
result_def,
)
from xdsl.parser import Parser
from xdsl.printer import Printer

_StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True)
_StreamTypeElementConstrT = TypeVar("_StreamTypeElementConstrT", bound=Attribute)
Expand Down Expand Up @@ -101,28 +97,15 @@ class ReadOperation(IRDLOperation, abc.ABC):
stream = operand_def(ReadableStreamType.constr(T))
res = result_def(T)

assembly_format = "`from` $stream attr-dict `:` type($res)"

def __init__(self, stream: SSAValue, result_type: Attribute | None = None):
if result_type is None:
assert isinstance(stream_type := stream.type, ReadableStreamType)
stream_type = cast(ReadableStreamType[Attribute], stream_type)
result_type = stream_type.element_type
super().__init__(operands=[stream], result_types=[result_type])

@classmethod
def parse(cls, parser: Parser) -> Self:
parser.parse_characters("from")
unresolved = parser.parse_unresolved_operand()
parser.parse_punctuation(":")
result_type = parser.parse_attribute()
resolved = parser.resolve_operand(unresolved, ReadableStreamType(result_type))
return cls(resolved, result_type)

def print(self, printer: Printer):
printer.print_string(" from ")
printer.print(self.stream)
printer.print_string(" : ")
printer.print_attribute(self.res.type)


class WriteOperation(IRDLOperation, abc.ABC):
"""
Expand All @@ -134,30 +117,11 @@ class WriteOperation(IRDLOperation, abc.ABC):
value = operand_def(T)
stream = operand_def(WritableStreamType.constr(T))

assembly_format = "$value `to` $stream attr-dict `:` type($value)"

def __init__(self, value: SSAValue, stream: SSAValue):
super().__init__(operands=[value, stream])

@classmethod
def parse(cls, parser: Parser) -> Self:
unresolved_value = parser.parse_unresolved_operand()
parser.parse_characters("to")
unresolved_stream = parser.parse_unresolved_operand()
parser.parse_punctuation(":")
result_type = parser.parse_attribute()
resolved_value = parser.resolve_operand(unresolved_value, result_type)
resolved_stream = parser.resolve_operand(
unresolved_stream, WritableStreamType(result_type)
)
return cls(resolved_value, resolved_stream)

def print(self, printer: Printer):
printer.print_string(" ")
printer.print_ssa_value(self.value)
printer.print_string(" to ")
printer.print_ssa_value(self.stream)
printer.print_string(" : ")
printer.print_attribute(self.value.type)


Stream = Dialect(
"stream",
Expand Down

0 comments on commit f7088a0

Please sign in to comment.