Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow measure_entanglement to calculate entanglement for multiple qubits #206

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 91 additions & 13 deletions unitary/alpha/quantum_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from unitary.alpha.sparse_vector_simulator import PostSelectOperation, SparseSimulator
from unitary.alpha.qudit_state_transform import qudit_to_qubit_unitary, num_bits
import numpy as np
import itertools
from itertools import combinations
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer import itertools then itertools.combinations below

"Use import statements for packages and modules only, not for individual types, classes, or functions."
https://google.github.io/styleguide/pyguide.html#22-imports

import pandas as pd


class QuantumWorld:
Expand Down Expand Up @@ -689,26 +690,103 @@ def density_matrix(
2**num_shown_qubits, 2**num_shown_qubits
)

def measure_entanglement(self, obj1: QuantumObject, obj2: QuantumObject) -> float:
"""Measures the entanglement (i.e. quantum mutual information) of the two given objects.
def measure_entanglement(
self, objects: Optional[Sequence[QuantumObject]] = None
) -> float:
"""Measures the entanglement (i.e. quantum mutual information) of the given objects.
See https://en.wikipedia.org/wiki/Quantum_mutual_information for the formula.

Parameters:
obj1, obj2: two quantum objects (currently only qubits are supported)
objects: quantum objects among which the entanglement will be calculated
(currently only qubits are supported). If not specified, all current
quantum objects will be used. If specified, at least two quantum
objects are expected.

Returns:
The quantum mutual information defined as S_1 + S_2 - S_12, where S denotes (reduced)
von Neumann entropy.
The quantum mutual information. For 2 qubits it's defined as S_1 + S_2 - S_12,
where S denotes (reduced) von Neumann entropy.
"""
density_matrix_12 = self.density_matrix([obj1, obj2]).reshape(2, 2, 2, 2)
density_matrix_1 = cirq.partial_trace(density_matrix_12, [0])
density_matrix_2 = cirq.partial_trace(density_matrix_12, [1])
return (
cirq.von_neumann_entropy(density_matrix_1, validate=False)
+ cirq.von_neumann_entropy(density_matrix_2, validate=False)
- cirq.von_neumann_entropy(density_matrix_12.reshape(4, 4), validate=False)
num_involved_objects = (
len(objects) if objects is not None else len(self.object_name_dict.values())
)

if num_involved_objects < 2:
raise ValueError(
f"Could not calculate entanglement for {num_involved_objects} qubit. "
"At least 2 qubits are required."
)

involved_objects = (
objects if objects is not None else list(self.object_name_dict.values())
)

density_matrix = self.density_matrix(involved_objects)
reshaped_density_matrix = density_matrix.reshape((2, 2) * num_involved_objects)
result = 0.0
for comb in combinations(range(num_involved_objects), num_involved_objects - 1):
reshaped_partial_density_matrix = cirq.partial_trace(
reshaped_density_matrix, list(comb)
)
partial_density_matrix = reshaped_partial_density_matrix.reshape(
2 ** (num_involved_objects - 1), 2 ** (num_involved_objects - 1)
)
result += cirq.von_neumann_entropy(partial_density_matrix, validate=False)
result -= cirq.von_neumann_entropy(density_matrix, validate=False)
return result

def print_entanglement_table(self, count: int = 1000) -> None:
"""Peek the current quantum world `count` times, and calculate pair-wise entanglement
(i.e. quantum mutual information) for each pair of quantum objects.
See https://en.wikipedia.org/wiki/Quantum_mutual_information for the formula. And print
the results out in a table.

Parameters:
count: Number of measurements.
"""
objects = list(self.object_name_dict.values())
num_qubits = len(objects)
if num_qubits < 2:
raise ValueError(
f"There is only {num_qubits} qubit in the quantum world. "
"At least 2 qubits are required to calculate entanglements."
)
# Peek the current world `count` times and get the results.
histogram = self.get_correlated_histogram(objects, count)

# Get an estimate of the state vector.
state_vector = np.array([0.0] * (2**num_qubits))
for key, val in histogram.items():
state_vector += self.__to_state_vector__(key) * np.sqrt(val * 1.0 / count)
density_matrix = np.outer(state_vector, state_vector)
reshaped_density_matrix = density_matrix.reshape((2, 2) * num_qubits)

entropy = [0.0] * num_qubits
entropy_pair = np.zeros((num_qubits, num_qubits))
entanglement = np.zeros((num_qubits, num_qubits))
for i in range(num_qubits - 1):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this block is fairly tough to follow. I would add comments below or in the function's docstring to explain the implementation details.

for j in range(i + 1, num_qubits):
density_matrix_ij = cirq.partial_trace(reshaped_density_matrix, [i, j])
entropy_pair[i][j] = cirq.von_neumann_entropy(
density_matrix_ij.reshape(4, 4), validate=False
)
if i == 0:
# Fill in entropy [0]
if j == i + 1:
density_matrix_i = cirq.partial_trace(density_matrix_ij, [0])
entropy[i] = cirq.von_neumann_entropy(
density_matrix_i, validate=False
)
# Fill in entropy [1 to num_qubit - 1]
density_matrix_j = cirq.partial_trace(density_matrix_ij, [1])
entropy[j] = cirq.von_neumann_entropy(
density_matrix_j, validate=False
)
entanglement[i][j] = entropy[i] + entropy[j] - entropy_pair[i][j]
entanglement[j][i] = entanglement[i][j]
names = list(self.object_name_dict.keys())
data_frame = pd.DataFrame(entanglement, index=names, columns=names)
print(data_frame.round(1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we return the data_frame (or some other result) instead of printing it?
Printing it seems generally unhelpful if we are using this within a game.


def __getitem__(self, name: str) -> QuantumObject:
quantum_object = self.object_name_dict.get(name, None)
if not quantum_object:
Expand Down
83 changes: 72 additions & 11 deletions unitary/alpha/quantum_world_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import unitary.alpha as alpha
import unitary.alpha.qudit_gates as qudit_gates
import io
import contextlib


class Light(enum.Enum):
Expand Down Expand Up @@ -206,7 +208,6 @@ def test_unhook(simulator, compile_to_qubits):
alpha.Split()(light, light2, light3)
board.unhook(light2)
results = board.peek([light2, light3], count=200, convert_to_enum=False)
print(results)
assert all(result[0] == 0 for result in results)
assert not all(result[1] == 0 for result in results)
assert not all(result[1] == 1 for result in results)
Expand Down Expand Up @@ -662,8 +663,6 @@ def test_combine_worlds(simulator, compile_to_qubits):

results = world2.peek(count=100)
expected = [StopLight.YELLOW] + result
print(results)
print(expected)
assert all(actual == expected for actual in results)


Expand Down Expand Up @@ -958,18 +957,80 @@ def test_measure_entanglement(simulator, compile_to_qubits):
)

# S_1 + S_2 - S_12 = 0 + 0 - 0 = 0 for all three cases.
assert round(board.measure_entanglement(light1, light2)) == 0.0
assert round(board.measure_entanglement(light1, light3)) == 0.0
assert round(board.measure_entanglement(light2, light3)) == 0.0
assert round(board.measure_entanglement([light1, light2]), 1) == 0.0
assert round(board.measure_entanglement([light1, light3]), 1) == 0.0
assert round(board.measure_entanglement([light2, light3]), 1) == 0.0
# S_12 + S_13 + S_23 - S_123 = 0 + 0 + 0 - 0 = 0
assert round(board.measure_entanglement([light1, light2, light3]), 1) == 0.0
# Test with objects=None.
assert round(board.measure_entanglement(), 1) == 0.0

alpha.Superposition()(light2)
alpha.quantum_if(light2).apply(alpha.Flip())(light3)
results = board.peek([light2, light3], count=100)
assert not all(result[0] == 0 for result in results)
assert (result[0] == result[1] for result in results)
# S_1 + S_2 - S_12 = 0 + 1 - 1 = 0
assert round(board.measure_entanglement(light1, light2), 1) == 0.0
# S_1 + S_2 - S_12 = 0 + 1 - 1 = 0
assert round(board.measure_entanglement(light1, light3), 1) == 0.0
# S_1 + S_2 - S_12 = 1 + 1 - 0 = 2
assert round(board.measure_entanglement(light2, light3), 1) == 2.0
assert round(board.measure_entanglement([light1, light2]), 1) == 0.0
# S_1 + S_3 - S_13 = 0 + 1 - 1 = 0
assert round(board.measure_entanglement([light1, light3]), 1) == 0.0
# S_2 + S_3 - S_23 = 1 + 1 - 0 = 2
assert round(board.measure_entanglement([light2, light3]), 1) == 2.0
# S_12 + S_13 + S_23 - S_123 = 1 + 1 + 0 - 0
assert round(board.measure_entanglement([light1, light2, light3]), 1) == 2.0
# Test with objects=None.
assert round(board.measure_entanglement(), 1) == 2.0
# Supplying one object would return a value error.
with pytest.raises(
ValueError, match="Could not calculate entanglement for 1 qubit."
):
board.measure_entanglement([light1])


@pytest.mark.parametrize(
("simulator", "compile_to_qubits"),
[
(cirq.Simulator, False),
(cirq.Simulator, True),
# Cannot use SparseSimulator without `compile_to_qubits` due to issue #78.
(alpha.SparseSimulator, True),
],
)
def test_print_entanglement_table(simulator, compile_to_qubits):
rho_green = np.reshape([0, 0, 0, 1], (2, 2))
rho_red = np.reshape([1, 0, 0, 0], (2, 2))
light1 = alpha.QuantumObject("red1", Light.RED)
light2 = alpha.QuantumObject("green", Light.GREEN)
light3 = alpha.QuantumObject("red2", Light.RED)
board = alpha.QuantumWorld(
[light1, light2, light3],
sampler=simulator(),
compile_to_qubits=compile_to_qubits,
)
f = io.StringIO()
with contextlib.redirect_stdout(f):
board.print_entanglement_table()
assert (
f.getvalue()
in """
red1 green red2
red1 0.0 0.0 0.0
green 0.0 0.0 0.0
red2 0.0 0.0 0.0
"""
)

alpha.Superposition()(light2)
alpha.quantum_if(light2).apply(alpha.Flip())(light3)
f = io.StringIO()
with contextlib.redirect_stdout(f):
board.print_entanglement_table()
assert (
f.getvalue()
in """
red1 green red2
red1 0.0 0.0 0.0
green 0.0 0.0 2.0
red2 0.0 2.0 0.0
"""
)
Loading