Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (stim) Add parser for .stim formatted strings #3122

Merged
merged 40 commits into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
e71d99a
Add stim dialect
kimxworrall Aug 15, 2024
838d11e
Add tests
kimxworrall Aug 15, 2024
82ae376
core: CanonicaliZation naming consistency (#3040)
PapyChacal Aug 15, 2024
22c247f
update stim
kimxworrall Aug 27, 2024
f3f785a
Pull region assembly format update by merging branch 'main' into kim/…
kimxworrall Aug 28, 2024
e6c2ebe
initialise stim print and parse, reorganise files to have stim relate…
kimxworrall Aug 28, 2024
06662bc
fix tests
kimxworrall Aug 28, 2024
32585bc
fix precommit
kimxworrall Aug 28, 2024
2cfbf1a
Update __init__.py
kimxworrall Aug 28, 2024
a5c8b34
Update tests/filecheck/dialects/stim/stim_ops.mlir
kimxworrall Aug 28, 2024
5e99704
Update tests/filecheck/dialects/stim/stim_ops.mlir
kimxworrall Aug 28, 2024
123328a
Add qubit coordinates annotation and attributes.
kimxworrall Aug 29, 2024
bf56513
add tests for stim printer
kimxworrall Aug 29, 2024
1c4e0a1
Clean tests and add qubit coordinates printer
kimxworrall Aug 29, 2024
a63cb4d
Remove unnecessary files
kimxworrall Aug 29, 2024
94ce681
dialects: (stim) Add qubit attribute and qubit coordinate attribute
kimxworrall Aug 29, 2024
febe161
Remove StimOp reference
kimxworrall Aug 29, 2024
ff2b0c0
Align with precommit
kimxworrall Aug 29, 2024
2428317
Remove unused test functions
kimxworrall Aug 29, 2024
dc1f1b3
align with precommit
kimxworrall Aug 29, 2024
57e4162
Re-Add print_lists
kimxworrall Aug 29, 2024
7f9cc6b
dialects: (stim) Add first annotation op
kimxworrall Aug 29, 2024
91fa1c0
Fix syntax errors
kimxworrall Aug 29, 2024
7069931
Apply suggestions from code review
kimxworrall Aug 30, 2024
ae824ba
Add tests for stim printer parser
kimxworrall Aug 30, 2024
bcd210c
move syntax tests to filecheck
kimxworrall Aug 30, 2024
6c74405
replace stimattr with stimprintable
kimxworrall Aug 30, 2024
fb5d65e
steramline print functions
kimxworrall Aug 30, 2024
7d3ae2f
dialects: (stim) Add first annotation op
kimxworrall Aug 29, 2024
b1ea8c6
Fix syntax errors
kimxworrall Aug 29, 2024
e1f6519
Merge branch 'kim/stim/first-annotation-op' of https://github.com/xds…
kimxworrall Aug 30, 2024
beea9c7
align with first-attributes updates
kimxworrall Aug 30, 2024
b1a389a
add filecheck tests for assign_qubit_coord op
kimxworrall Aug 30, 2024
efcf706
align with precommit
kimxworrall Aug 30, 2024
c0acb9f
reremove stimop in favour of having circuitop inherit stimprintable
kimxworrall Aug 30, 2024
b02284a
Add stim parser and update printer to print operations. Also add test…
kimxworrall Aug 31, 2024
bd71f1c
Align with precommit
kimxworrall Aug 31, 2024
0f59338
Add comments and match parsing of parens to stim parser
kimxworrall Sep 1, 2024
3c1ef8c
align with main
kimxworrall Oct 2, 2024
d7e2c44
Merge branch 'main' into kim/stim/stim-printer-parser
superlopuh Nov 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 61 additions & 1 deletion tests/dialects/stim/test_stim_printer_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import pytest

from xdsl.dialects import stim
from xdsl.dialects.stim.ops import QubitAttr, QubitCoordsOp, QubitMappingAttr
from xdsl.dialects.stim.ops import (
QubitAttr,
QubitCoordsOp,
QubitMappingAttr,
)
from xdsl.dialects.stim.stim_parser import StimParseError, StimParser
from xdsl.dialects.stim.stim_printer_parser import StimPrintable, StimPrinter
from xdsl.dialects.test import TestOp
from xdsl.ir import Block, Region
Expand All @@ -20,6 +25,19 @@ def check_stim_print(program: StimPrintable, expected_stim: str):
assert expected_stim == res_io.getvalue()


def check_stim_roundtrip(program: str):
"""Check that the given program roundtrips exactly (including whitespaces)."""
stim_parser = StimParser(program)
stim_circuit = stim_parser.parse_circuit()

check_stim_print(stim_circuit, program)


################################################################################
# Test operations stim_print() #
################################################################################


def test_empty_circuit():
empty_block = Block()
empty_region = Region(empty_block)
Expand Down Expand Up @@ -60,3 +78,45 @@ def test_print_stim_qubit_coord_op():
qubit_annotation = QubitCoordsOp(qubit_coord)
expected_stim = "QUBIT_COORDS(0, 0) 0"
check_stim_print(qubit_annotation, expected_stim)


################################################################################
# Test stim parser and printer #
################################################################################


@pytest.mark.parametrize(
"program",
[(""), ("\n"), ("#hi"), ("# hi \n" "#hi\n")],
)
def test_stim_roundtrip_empty_circuit(program: str):
stim_parser = StimParser(program)
stim_circuit = stim_parser.parse_circuit()
check_stim_print(stim_circuit, "")


@pytest.mark.parametrize(
"program",
[
("QUBIT_COORDS() 0\n"),
("QUBIT_COORDS(0, 0) 0\n"),
("QUBIT_COORDS(0, 2) 1\n"),
("QUBIT_COORDS(0, 0) 0\n" "QUBIT_COORDS(1, 2) 2\n"),
],
)
def test_stim_roundtrip_qubit_coord_op(program: str):
check_stim_roundtrip(program)


def test_no_spaces_before_target():
with pytest.raises(StimParseError, match="Targets must be separated by spacing."):
program = "QUBIT_COORDS(1, 1)1"
parser = StimParser(program)
parser.parse_circuit()


def test_no_targets():
program = "QUBIT_COORDS(1, 1)"
with pytest.raises(StimParseError, match="Expected at least one target"):
parser = StimParser(program)
parser.parse_circuit()
24 changes: 15 additions & 9 deletions xdsl/dialects/stim/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Sequence
from io import StringIO

from xdsl.dialects.builtin import ArrayAttr, IntAttr
from xdsl.dialects.builtin import ArrayAttr, FloatData, IntAttr
from xdsl.dialects.stim.stim_printer_parser import StimPrintable, StimPrinter
from xdsl.ir import ParametrizedAttribute, Region, TypeAttribute
from xdsl.irdl import (
Expand Down Expand Up @@ -63,26 +63,33 @@ class QubitMappingAttr(StimPrintable, ParametrizedAttribute):

name = "stim.qubit_coord"

coords: ParameterDef[ArrayAttr[IntAttr]]
coords: ParameterDef[ArrayAttr[FloatData | IntAttr]]
qubit_name: ParameterDef[QubitAttr]

def __init__(
self, coords: list[int] | ArrayAttr[IntAttr], qubit_name: int | QubitAttr
self,
coords: list[float] | ArrayAttr[FloatData | IntAttr],
qubit_name: int | QubitAttr,
) -> None:
if not isinstance(qubit_name, QubitAttr):
qubit_name = QubitAttr(qubit_name)
if not isinstance(coords, ArrayAttr):
coords = ArrayAttr(IntAttr(c) for c in coords)
coords = ArrayAttr(
(IntAttr(int(arg))) if (type(arg) is int) else (FloatData(arg))
for arg in coords
)
super().__init__(parameters=[coords, qubit_name])

@classmethod
def parse_parameters(
cls, parser: AttrParser
) -> tuple[ArrayAttr[IntAttr], QubitAttr]:
) -> tuple[ArrayAttr[FloatData | IntAttr], QubitAttr]:
parser.parse_punctuation("<")
coords = parser.parse_comma_separated_list(
delimiter=parser.Delimiter.PAREN,
parse=lambda: IntAttr(parser.parse_integer(allow_boolean=False)),
parse=lambda: IntAttr(x)
if type(x := parser.parse_number(allow_boolean=False)) is int
else FloatData(x),
)
parser.parse_punctuation(",")
qubit = parser.parse_attribute()
Expand Down Expand Up @@ -128,9 +135,8 @@ def verify(self, verify_nested_ops: bool = True) -> None:

def print_stim(self, printer: StimPrinter):
for op in self.body.block.ops:
if not isinstance(op, StimPrintable):
raise ValueError(f"Cannot print in stim format: {op}")
op.print_stim(printer)
printer.print_op(op)
printer.print_string("\n")
printer.print_string("")

def stim(self) -> str:
Expand Down
Loading
Loading