diff --git a/unitary/alpha/quantum_world.py b/unitary/alpha/quantum_world.py index 6235b908..dcfcb527 100644 --- a/unitary/alpha/quantum_world.py +++ b/unitary/alpha/quantum_world.py @@ -20,7 +20,8 @@ from unitary.alpha.sparse_vector_simulator import PostSelectOperation, SparseSimulator from unitary.alpha.qudit_state_transform import qudit_to_qubit_unitary, num_bits import numpy as np -import itertools +from itertools import combinations +import pandas as pd class QuantumWorld: @@ -689,26 +690,103 @@ def density_matrix( 2**num_shown_qubits, 2**num_shown_qubits ) - def measure_entanglement(self, obj1: QuantumObject, obj2: QuantumObject) -> float: - """Measures the entanglement (i.e. quantum mutual information) of the two given objects. + def measure_entanglement( + self, objects: Optional[Sequence[QuantumObject]] = None + ) -> float: + """Measures the entanglement (i.e. quantum mutual information) of the given objects. See https://en.wikipedia.org/wiki/Quantum_mutual_information for the formula. Parameters: - obj1, obj2: two quantum objects (currently only qubits are supported) + objects: quantum objects among which the entanglement will be calculated + (currently only qubits are supported). If not specified, all current + quantum objects will be used. If specified, at least two quantum + objects are expected. Returns: - The quantum mutual information defined as S_1 + S_2 - S_12, where S denotes (reduced) - von Neumann entropy. + The quantum mutual information. For 2 qubits it's defined as S_1 + S_2 - S_12, + where S denotes (reduced) von Neumann entropy. """ - density_matrix_12 = self.density_matrix([obj1, obj2]).reshape(2, 2, 2, 2) - density_matrix_1 = cirq.partial_trace(density_matrix_12, [0]) - density_matrix_2 = cirq.partial_trace(density_matrix_12, [1]) - return ( - cirq.von_neumann_entropy(density_matrix_1, validate=False) - + cirq.von_neumann_entropy(density_matrix_2, validate=False) - - cirq.von_neumann_entropy(density_matrix_12.reshape(4, 4), validate=False) + num_involved_objects = ( + len(objects) if objects is not None else len(self.object_name_dict.values()) ) + if num_involved_objects < 2: + raise ValueError( + f"Could not calculate entanglement for {num_involved_objects} qubit. " + "At least 2 qubits are required." + ) + + involved_objects = ( + objects if objects is not None else list(self.object_name_dict.values()) + ) + + density_matrix = self.density_matrix(involved_objects) + reshaped_density_matrix = density_matrix.reshape((2, 2) * num_involved_objects) + result = 0.0 + for comb in combinations(range(num_involved_objects), num_involved_objects - 1): + reshaped_partial_density_matrix = cirq.partial_trace( + reshaped_density_matrix, list(comb) + ) + partial_density_matrix = reshaped_partial_density_matrix.reshape( + 2 ** (num_involved_objects - 1), 2 ** (num_involved_objects - 1) + ) + result += cirq.von_neumann_entropy(partial_density_matrix, validate=False) + result -= cirq.von_neumann_entropy(density_matrix, validate=False) + return result + + def print_entanglement_table(self, count: int = 1000) -> None: + """Peek the current quantum world `count` times, and calculate pair-wise entanglement + (i.e. quantum mutual information) for each pair of quantum objects. + See https://en.wikipedia.org/wiki/Quantum_mutual_information for the formula. And print + the results out in a table. + + Parameters: + count: Number of measurements. + """ + objects = list(self.object_name_dict.values()) + num_qubits = len(objects) + if num_qubits < 2: + raise ValueError( + f"There is only {num_qubits} qubit in the quantum world. " + "At least 2 qubits are required to calculate entanglements." + ) + # Peek the current world `count` times and get the results. + histogram = self.get_correlated_histogram(objects, count) + + # Get an estimate of the state vector. + state_vector = np.array([0.0] * (2**num_qubits)) + for key, val in histogram.items(): + state_vector += self.__to_state_vector__(key) * np.sqrt(val * 1.0 / count) + density_matrix = np.outer(state_vector, state_vector) + reshaped_density_matrix = density_matrix.reshape((2, 2) * num_qubits) + + entropy = [0.0] * num_qubits + entropy_pair = np.zeros((num_qubits, num_qubits)) + entanglement = np.zeros((num_qubits, num_qubits)) + for i in range(num_qubits - 1): + for j in range(i + 1, num_qubits): + density_matrix_ij = cirq.partial_trace(reshaped_density_matrix, [i, j]) + entropy_pair[i][j] = cirq.von_neumann_entropy( + density_matrix_ij.reshape(4, 4), validate=False + ) + if i == 0: + # Fill in entropy [0] + if j == i + 1: + density_matrix_i = cirq.partial_trace(density_matrix_ij, [0]) + entropy[i] = cirq.von_neumann_entropy( + density_matrix_i, validate=False + ) + # Fill in entropy [1 to num_qubit - 1] + density_matrix_j = cirq.partial_trace(density_matrix_ij, [1]) + entropy[j] = cirq.von_neumann_entropy( + density_matrix_j, validate=False + ) + entanglement[i][j] = entropy[i] + entropy[j] - entropy_pair[i][j] + entanglement[j][i] = entanglement[i][j] + names = list(self.object_name_dict.keys()) + data_frame = pd.DataFrame(entanglement, index=names, columns=names) + print(data_frame.round(1)) + def __getitem__(self, name: str) -> QuantumObject: quantum_object = self.object_name_dict.get(name, None) if not quantum_object: diff --git a/unitary/alpha/quantum_world_test.py b/unitary/alpha/quantum_world_test.py index f8880c9d..274137a7 100644 --- a/unitary/alpha/quantum_world_test.py +++ b/unitary/alpha/quantum_world_test.py @@ -23,6 +23,8 @@ import unitary.alpha as alpha import unitary.alpha.qudit_gates as qudit_gates +import io +import contextlib class Light(enum.Enum): @@ -206,7 +208,6 @@ def test_unhook(simulator, compile_to_qubits): alpha.Split()(light, light2, light3) board.unhook(light2) results = board.peek([light2, light3], count=200, convert_to_enum=False) - print(results) assert all(result[0] == 0 for result in results) assert not all(result[1] == 0 for result in results) assert not all(result[1] == 1 for result in results) @@ -662,8 +663,6 @@ def test_combine_worlds(simulator, compile_to_qubits): results = world2.peek(count=100) expected = [StopLight.YELLOW] + result - print(results) - print(expected) assert all(actual == expected for actual in results) @@ -958,9 +957,13 @@ def test_measure_entanglement(simulator, compile_to_qubits): ) # S_1 + S_2 - S_12 = 0 + 0 - 0 = 0 for all three cases. - assert round(board.measure_entanglement(light1, light2)) == 0.0 - assert round(board.measure_entanglement(light1, light3)) == 0.0 - assert round(board.measure_entanglement(light2, light3)) == 0.0 + assert round(board.measure_entanglement([light1, light2]), 1) == 0.0 + assert round(board.measure_entanglement([light1, light3]), 1) == 0.0 + assert round(board.measure_entanglement([light2, light3]), 1) == 0.0 + # S_12 + S_13 + S_23 - S_123 = 0 + 0 + 0 - 0 = 0 + assert round(board.measure_entanglement([light1, light2, light3]), 1) == 0.0 + # Test with objects=None. + assert round(board.measure_entanglement(), 1) == 0.0 alpha.Superposition()(light2) alpha.quantum_if(light2).apply(alpha.Flip())(light3) @@ -968,8 +971,66 @@ def test_measure_entanglement(simulator, compile_to_qubits): assert not all(result[0] == 0 for result in results) assert (result[0] == result[1] for result in results) # S_1 + S_2 - S_12 = 0 + 1 - 1 = 0 - assert round(board.measure_entanglement(light1, light2), 1) == 0.0 - # S_1 + S_2 - S_12 = 0 + 1 - 1 = 0 - assert round(board.measure_entanglement(light1, light3), 1) == 0.0 - # S_1 + S_2 - S_12 = 1 + 1 - 0 = 2 - assert round(board.measure_entanglement(light2, light3), 1) == 2.0 + assert round(board.measure_entanglement([light1, light2]), 1) == 0.0 + # S_1 + S_3 - S_13 = 0 + 1 - 1 = 0 + assert round(board.measure_entanglement([light1, light3]), 1) == 0.0 + # S_2 + S_3 - S_23 = 1 + 1 - 0 = 2 + assert round(board.measure_entanglement([light2, light3]), 1) == 2.0 + # S_12 + S_13 + S_23 - S_123 = 1 + 1 + 0 - 0 + assert round(board.measure_entanglement([light1, light2, light3]), 1) == 2.0 + # Test with objects=None. + assert round(board.measure_entanglement(), 1) == 2.0 + # Supplying one object would return a value error. + with pytest.raises( + ValueError, match="Could not calculate entanglement for 1 qubit." + ): + board.measure_entanglement([light1]) + + +@pytest.mark.parametrize( + ("simulator", "compile_to_qubits"), + [ + (cirq.Simulator, False), + (cirq.Simulator, True), + # Cannot use SparseSimulator without `compile_to_qubits` due to issue #78. + (alpha.SparseSimulator, True), + ], +) +def test_print_entanglement_table(simulator, compile_to_qubits): + rho_green = np.reshape([0, 0, 0, 1], (2, 2)) + rho_red = np.reshape([1, 0, 0, 0], (2, 2)) + light1 = alpha.QuantumObject("red1", Light.RED) + light2 = alpha.QuantumObject("green", Light.GREEN) + light3 = alpha.QuantumObject("red2", Light.RED) + board = alpha.QuantumWorld( + [light1, light2, light3], + sampler=simulator(), + compile_to_qubits=compile_to_qubits, + ) + f = io.StringIO() + with contextlib.redirect_stdout(f): + board.print_entanglement_table() + assert ( + f.getvalue() + in """ + red1 green red2 +red1 0.0 0.0 0.0 +green 0.0 0.0 0.0 +red2 0.0 0.0 0.0 + """ + ) + + alpha.Superposition()(light2) + alpha.quantum_if(light2).apply(alpha.Flip())(light3) + f = io.StringIO() + with contextlib.redirect_stdout(f): + board.print_entanglement_table() + assert ( + f.getvalue() + in """ + red1 green red2 +red1 0.0 0.0 0.0 +green 0.0 0.0 2.0 +red2 0.0 2.0 0.0 +""" + )