Skip to content

Commit

Permalink
Merge pull request #1431 from qiboteam/measurement_result_serialization
Browse files Browse the repository at this point in the history
Serialization of measurements
  • Loading branch information
scarrazza authored Sep 4, 2024
2 parents 47cb206 + 7864590 commit 4f2314a
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/qibo/backends/clifford.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def execute_circuit_repeated(self, circuit, nshots: int = 1000, initial_state=No
samples = self.np.vstack(samples)

for meas in circuit.measurements:
meas.result.register_samples(samples[:, meas.target_qubits], self)
meas.result.register_samples(samples[:, meas.target_qubits])

result = Clifford(
self.zero_state(circuit.nqubits),
Expand Down
15 changes: 14 additions & 1 deletion src/qibo/gates/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,18 @@
"_target_qubits",
"_control_qubits",
]
REQUIRED_FIELDS_INIT_KWARGS = ["theta", "phi", "lam", "phi0", "phi1"]
REQUIRED_FIELDS_INIT_KWARGS = [
"theta",
"phi",
"lam",
"phi0",
"phi1",
"register_name",
"collapse",
"basis",
"p0",
"p1",
]


class Gate:
Expand Down Expand Up @@ -107,6 +118,8 @@ def from_dict(raw: dict):
raise ValueError(f"Unknown gate {raw['_class']}")

gate = cls(*raw["init_args"], **raw["init_kwargs"])
if raw["_class"] == "M" and raw["measurement_result"]["samples"] is not None:
gate.result.register_samples(raw["measurement_result"]["samples"])
try:
return gate.controlled_by(*raw["_control_qubits"])
except RuntimeError as e:
Expand Down
71 changes: 37 additions & 34 deletions src/qibo/gates/measurements.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Union

from qibo import gates
from qibo.config import raise_error
Expand All @@ -23,11 +23,13 @@ class M(Gate):
performed. Can be used only for single shot measurements.
If ``True`` the collapsed state vector is returned. If ``False``
the measurement result is returned.
basis (:class:`qibo.gates.Gate`, list): Basis to measure.
Can be a qibo gate or a callable that accepts a qubit,
for example: ``lambda q: gates.RX(q, 0.2)``
or a list of these, if a different basis will be used for each
measurement qubit.
basis (:class:`qibo.gates.Gate`, str, list): Basis to measure.
Can be either:
- a qibo gate
- the string representing the gate
- a callable that accepts a qubit, for example: ``lambda q: gates.RX(q, 0.2)``
- a list of the above, if a different basis will be used for each
measurement qubit.
Default is Z.
p0 (dict): Optional bitflip probability map. Can be:
A dictionary that maps each measured qubit to the probability
Expand All @@ -46,7 +48,7 @@ def __init__(
*q,
register_name: Optional[str] = None,
collapse: bool = False,
basis: Gate = Z,
basis: Union[Gate, str] = Z,
p0: Optional["ProbsType"] = None,
p1: Optional["ProbsType"] = None,
):
Expand All @@ -61,15 +63,24 @@ def __init__(
# relevant for experiments only
self.pulses = None
# saving basis for __repr__ ans save to file
to_gate = lambda x: getattr(gates, x) if isinstance(x, str) else x
if not isinstance(basis, list):
self.basis_gates = len(q) * [basis]
self.basis_gates = len(q) * [to_gate(basis)]
basis = len(self.target_qubits) * [basis]
elif len(basis) != len(self.target_qubits):
raise_error(
ValueError,
f"Given basis list has length {len(basis)} while "
f"we are measuring {len(self.target_qubits)} qubits.",
)
else:
self.basis_gates = basis
self.basis_gates = [to_gate(g) for g in basis]

self.init_args = q
self.init_kwargs = {
"register_name": register_name,
"collapse": collapse,
"basis": [g.__name__ for g in self.basis_gates],
"p0": p0,
"p1": p1,
}
Expand All @@ -88,20 +99,25 @@ def __init__(

# list of gates that will be added to the circuit before the
# measurement, in order to rotate to the given basis
if not isinstance(basis, list):
basis = len(self.target_qubits) * [basis]
elif len(basis) != len(self.target_qubits):
raise_error(
ValueError,
f"Given basis list has length {len(basis)} while "
f"we are measuring {len(self.target_qubits)} qubits.",
)
self.basis = []
for qubit, basis_cls in zip(self.target_qubits, basis):
for qubit, basis_cls in zip(self.target_qubits, self.basis_gates):
gate = basis_cls(qubit).basis_rotation()
if gate is not None:
self.basis.append(gate)

@property
def raw(self) -> dict:
"""Serialize to dictionary.
The values used in the serialization should be compatible with a
JSON dump (or any other one supporting a minimal set of scalar
types). Though the specific implementation is up to the specific
gate.
"""
encoded_simple = super().raw
encoded_simple.update({"measurement_result": self.result.raw})
return encoded_simple

@staticmethod
def _get_bitflip_tuple(
qubits: Tuple[int, ...], probs: "ProbsType"
Expand Down Expand Up @@ -178,7 +194,7 @@ def apply(self, backend, state, nqubits):
qubits = sorted(self.target_qubits)
# measure and get result
probs = backend.calculate_probabilities(state, qubits, nqubits)
shot = self.result.add_shot(probs)
shot = self.result.add_shot(probs, backend=backend)
# collapse state
return backend.collapse_state(state, qubits, shot, nqubits)

Expand All @@ -190,7 +206,7 @@ def apply_density_matrix(self, backend, state, nqubits):
qubits = sorted(self.target_qubits)
# measure and get result
probs = backend.calculate_probabilities_density_matrix(state, qubits, nqubits)
shot = self.result.add_shot(probs)
shot = self.result.add_shot(probs, backend=backend)
# collapse state
return backend.collapse_density_matrix(state, qubits, shot, nqubits)

Expand All @@ -204,25 +220,12 @@ def apply_clifford(self, backend, state, nqubits):
self.result.add_shot_from_sample(sample[0])
return state

def to_json(self):
"""Serializes the measurement gate to json."""
encoding = json.loads(super().to_json())
encoding.pop("_control_qubits")
encoding.update({"basis": [g.__name__ for g in self.basis_gates]})
return json.dumps(encoding)

@classmethod
def load(cls, payload):
"""Constructs a measurement gate starting from a json serialized
one."""
args = json.loads(payload)
# drop general serialization data, unused in this specialized loader
for key in ("name", "init_args", "_class"):
args.pop(key)
qubits = args.pop("_target_qubits")
args["basis"] = [getattr(gates, g) for g in args["basis"]]
args.update(args.pop("init_kwargs"))
return cls(*qubits, **args)
return cls.from_dict(args)

# Overload on_qubits to copy also gate.result, controlled by can be removed for measurements
def on_qubits(self, qubit_map) -> "Gate":
Expand Down
48 changes: 32 additions & 16 deletions src/qibo/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
from qibo.config import raise_error


def _check_backend(backend):
"""This is only needed due to the circular import with qibo.backends."""
from qibo.backends import _check_backend

return _check_backend(backend)


def frequencies_to_binary(frequencies, nqubits):
return collections.Counter(
{"{:b}".format(k).zfill(nqubits): v for k, v in frequencies.items()}
Expand Down Expand Up @@ -79,10 +86,8 @@ class MeasurementResult:
to use for calculations.
"""

def __init__(self, gate, nshots=0, backend=None):
def __init__(self, gate):
self.measurement_gate = gate
self.backend = backend
self.nshots = nshots
self.circuit = None

self._samples = None
Expand All @@ -96,36 +101,45 @@ def __repr__(self):
nshots = self.nshots
return f"MeasurementResult(qubits={qubits}, nshots={nshots})"

def add_shot(self, probs):
@property
def raw(self) -> dict:
samples = self._samples.tolist() if self.has_samples() else self._samples
return {"samples": samples}

@property
def nshots(self) -> int:
if self.has_samples():
return len(self._samples)
elif self._frequencies is not None:
return sum(self._frequencies.values())

def add_shot(self, probs, backend=None):
backend = _check_backend(backend)
qubits = sorted(self.measurement_gate.target_qubits)
shot = self.backend.sample_shots(probs, 1)
bshot = self.backend.samples_to_binary(shot, len(qubits))
shot = backend.sample_shots(probs, 1)
bshot = backend.samples_to_binary(shot, len(qubits))
if self._samples:
self._samples.append(bshot[0])
else:
self._samples = [bshot[0]]
self.nshots += 1
return shot

def add_shot_from_sample(self, sample):
if self._samples:
self._samples.append(sample)
else:
self._samples = [sample]
self.nshots += 1

def has_samples(self):
return self._samples is not None

def register_samples(self, samples, backend=None):
def register_samples(self, samples):
"""Register samples array to the ``MeasurementResult`` object."""
self._samples = samples
self.nshots = len(samples)

def register_frequencies(self, frequencies, backend=None):
def register_frequencies(self, frequencies):
"""Register frequencies to the ``MeasurementResult`` object."""
self._frequencies = frequencies
self.nshots = sum(frequencies.values())

def reset(self):
"""Remove all registered samples and frequencies."""
Expand All @@ -144,7 +158,7 @@ def symbols(self):

return self._symbols

def samples(self, binary=True, registers=False):
def samples(self, binary=True, registers=False, backend=None):
"""Returns raw measurement samples.
Args:
Expand All @@ -159,6 +173,7 @@ def samples(self, binary=True, registers=False):
samples are returned in decimal form as a tensor
of shape `(nshots,)`.
"""
backend = _check_backend(backend)
if self._samples is None:
if self.circuit is None:
raise_error(
Expand All @@ -172,9 +187,9 @@ def samples(self, binary=True, registers=False):
return self._samples

qubits = self.measurement_gate.target_qubits
return self.backend.samples_to_decimal(self._samples, len(qubits))
return backend.samples_to_decimal(self._samples, len(qubits))

def frequencies(self, binary=True, registers=False):
def frequencies(self, binary=True, registers=False, backend=None):
"""Returns the frequencies of measured samples.
Args:
Expand All @@ -192,8 +207,9 @@ def frequencies(self, binary=True, registers=False):
If `binary` is `False`
the keys of the `Counter` are integers.
"""
backend = _check_backend(backend)
if self._frequencies is None:
self._frequencies = self.backend.calculate_frequencies(
self._frequencies = backend.calculate_frequencies(
self.samples(binary=False)
)
if binary:
Expand Down
2 changes: 1 addition & 1 deletion src/qibo/quantum_info/clifford.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def samples(self, binary: bool = True, registers: bool = False):
self._samples = self._backend.cast(samples, dtype="int32")
for gate in self.measurements:
rqubits = tuple(qubit_map.get(q) for q in gate.target_qubits)
gate.result.register_samples(self._samples[:, rqubits], self._backend)
gate.result.register_samples(self._samples[:, rqubits])

if registers:
return {
Expand Down
6 changes: 2 additions & 4 deletions src/qibo/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def frequencies(self, binary: bool = True, registers: bool = False):
if int(bitstring[qubit_map.get(q)]):
idx += 2 ** (len(rqubits) - i - 1)
rfreqs[idx] += freq
gate.result.register_frequencies(rfreqs, self.backend)
gate.result.register_frequencies(rfreqs)
else:
self._frequencies = self.backend.calculate_frequencies(
self.samples(binary=False)
Expand Down Expand Up @@ -356,9 +356,7 @@ def samples(self, binary: bool = True, registers: bool = False):
self._samples = samples
for gate in self.measurements:
rqubits = tuple(qubit_map.get(q) for q in gate.target_qubits)
gate.result.register_samples(
self._samples[:, rqubits], self.backend
)
gate.result.register_samples(self._samples[:, rqubits])

if registers:
return {
Expand Down
40 changes: 40 additions & 0 deletions tests/test_measurements.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Test circuit result measurements and measurement gate and as part of circuit."""

import json
import pickle

import numpy as np
import pytest

from qibo import gates, models
from qibo.measurements import MeasurementResult


def assert_result(
Expand Down Expand Up @@ -473,3 +475,41 @@ def test_measurementsymbol_pickling(backend):
assert symbol.index == new_symbol.index
assert symbol.name == new_symbol.name
backend.assert_allclose(symbol.result.samples(), new_symbol.result.samples())


def test_measurementresult_nshots(backend):
gate = gates.M(*range(3))
result = MeasurementResult(gate)
# nshots starting from samples
nshots = 10
samples = backend.cast(
[[i % 2, i % 2, i % 2] for i in range(nshots)], backend.np.int64
)
result.register_samples(samples)
assert result.nshots == nshots
# nshots starting from frequencies
result = MeasurementResult(gate)
states, counts = np.unique(samples, axis=0, return_counts=True)
to_str = lambda x: [str(item) for item in x]
states = ["".join(to_str(s)) for s in states.tolist()]
freq = dict(zip(states, counts.tolist()))
result.register_frequencies(freq)
assert result.nshots == nshots


def test_measurement_serialization(backend):
kwargs = {
"register_name": "test",
"collapse": False,
"basis": ["Z", "X", "Y"],
"p0": 0.1,
"p1": 0.2,
}
gate = gates.M(*range(3), **kwargs)
samples = backend.cast(np.random.randint(2, size=(100, 3)), backend.np.int64)
gate.result.register_samples(samples)
dump = gate.to_json()
load = gates.M.from_dict(json.loads(dump))
for k, v in kwargs.items():
assert load.init_kwargs[k] == v
backend.assert_allclose(samples, load.result.samples())
Loading

0 comments on commit 4f2314a

Please sign in to comment.