Skip to content

Commit

Permalink
Merge branch 'main' into u/eliottrosenbrg/xeb_batch_size
Browse files Browse the repository at this point in the history
  • Loading branch information
eliottrosenberg authored Nov 22, 2024
2 parents b39e1b4 + e0087b0 commit 557d971
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 16 deletions.
48 changes: 32 additions & 16 deletions cirq-core/cirq/circuits/qasm_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Utility classes for representing QASM."""

from typing import Callable, Dict, Iterator, Optional, Sequence, Set, Tuple, Union, TYPE_CHECKING
from typing import Callable, Dict, Iterator, Optional, Sequence, Tuple, Union, TYPE_CHECKING

import re
import numpy as np
Expand Down Expand Up @@ -203,6 +203,7 @@ def __init__(
qubit_id_map=qubit_id_map,
meas_key_id_map=meas_key_id_map,
)
self.cregs = self._generate_cregs()

def _generate_measurement_ids(self) -> Tuple[Dict[str, str], Dict[str, Optional[str]]]:
# Pick an id for the creg that will store each measurement
Expand All @@ -226,6 +227,30 @@ def _generate_measurement_ids(self) -> Tuple[Dict[str, str], Dict[str, Optional[
def _generate_qubit_ids(self) -> Dict['cirq.Qid', str]:
return {qubit: f'q[{i}]' for i, qubit in enumerate(self.qubits)}

def _generate_cregs(self) -> Dict[str, tuple[int, str]]:
"""Pick an id for the creg that will store each measurement
This function finds the largest measurement using each key.
That is, if multiple measurements are made with the same key,
it will use the key with the most number of qubits.
Returns: dictionary with key of measurement id and value of (#qubits, comment).
"""
cregs: Dict[str, tuple[int, str]] = {}
for meas in self.measurements:
key = protocols.measurement_key_name(meas)
meas_id = self.args.meas_key_id_map[key]

if self.meas_comments[key] is not None:
comment = f' // Measurement: {self.meas_comments[key]}'
else:
comment = ''

if meas_id not in cregs or cregs[meas_id][0] < len(meas.qubits):
cregs[meas_id] = (len(meas.qubits), comment)

return cregs

def is_valid_qasm_id(self, id_str: str) -> bool:
"""Test if id_str is a valid id in QASM grammar."""
return self.valid_id_re.match(id_str) is not None
Expand Down Expand Up @@ -287,24 +312,15 @@ def output(text):
output(f'qreg q[{len(self.qubits)}];\n')
else:
output(f'qubit[{len(self.qubits)}] q;\n')
# Classical registers
# Pick an id for the creg that will store each measurement
already_output_keys: Set[str] = set()
for meas in self.measurements:
key = protocols.measurement_key_name(meas)
if key in already_output_keys:
continue
already_output_keys.add(key)
meas_id = self.args.meas_key_id_map[key]
if self.meas_comments[key] is not None:
comment = f' // Measurement: {self.meas_comments[key]}'
else:
comment = ''

# Classical registers
for meas_id in self.cregs:
length, comment = self.cregs[meas_id]
if self.args.version == '2.0':
output(f'creg {meas_id}[{len(meas.qubits)}];{comment}\n')
output(f'creg {meas_id}[{length}];{comment}\n')
else:
output(f'bit[{len(meas.qubits)}] {meas_id};{comment}\n')
output(f'bit[{length}] {meas_id};{comment}\n')

# In OpenQASM 2.0, the transformation of global phase gates is ignored.
# Therefore, no newline is created when the operations contained in
# a circuit consist only of global phase gates.
Expand Down
55 changes: 55 additions & 0 deletions cirq-core/cirq/circuits/qasm_output_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,58 @@ def test_reset():
reset q[1];
""".strip()
)


def test_different_sized_registers():
qubits = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.measure(qubits[0], key='c'), cirq.measure(qubits, key='c'))
output = cirq.QasmOutput(
c.all_operations(), tuple(sorted(c.all_qubits())), header='Generated from Cirq!'
)
assert (
str(output)
== """// Generated from Cirq!
OPENQASM 2.0;
include "qelib1.inc";
// Qubits: [q(0), q(1)]
qreg q[2];
creg m_c[2];
measure q[0] -> m_c[0];
// Gate: cirq.MeasurementGate(2, cirq.MeasurementKey(name='c'), ())
measure q[0] -> m_c[0];
measure q[1] -> m_c[1];
"""
)
# OPENQASM 3.0
output3 = cirq.QasmOutput(
c.all_operations(),
tuple(sorted(c.all_qubits())),
header='Generated from Cirq!',
version='3.0',
)
assert (
str(output3)
== """// Generated from Cirq!
OPENQASM 3.0;
include "stdgates.inc";
// Qubits: [q(0), q(1)]
qubit[2] q;
bit[2] m_c;
m_c[0] = measure q[0];
// Gate: cirq.MeasurementGate(2, cirq.MeasurementKey(name='c'), ())
m_c[0] = measure q[0];
m_c[1] = measure q[1];
"""
)

0 comments on commit 557d971

Please sign in to comment.