Commit 225903e: rewrite fix

renatomello committed Feb 26, 2024
1 parent 8fe4bf2 · commit 225903e

Showing 4 changed files with 32 additions and 84 deletions.
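The pattern is identical across all four files: call sites that branched on backend.name == "pytorch" to choose between backend.torch.* and plain np.* now route through the single array-module attribute backend.np. A minimal sketch of the idea, with illustrative backend classes (not qibo's actual implementations), assuming each backend exposes its array library under a common np attribute:

    import numpy as np

    class NumpyBackend:
        name = "numpy"

        def __init__(self):
            self.np = np  # the facade is numpy itself

    class PyTorchBackend:
        name = "pytorch"

        def __init__(self):
            import torch

            self.np = torch  # same attribute, backed by torch

    def total_probability_before(prob_dist, backend):
        # the dispatch pattern this commit deletes
        # (qibo's old code used backend.torch.sum on the pytorch branch)
        return (
            backend.np.sum(prob_dist)
            if backend.name == "pytorch"
            else np.sum(prob_dist)
        )

    def total_probability_after(prob_dist, backend):
        # the unified call that replaces it
        return backend.np.sum(prob_dist)

The rewrite is safe wherever numpy and torch agree on the function name and call shape (sum, mean, diag, kron, einsum, matmul); the one spot below where the spellings differ (torch's inverse vs numpy's np.linalg.inv) keeps its conditional.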
src/qibo/quantum_info/entropies.py (73 changes: 16 additions & 57 deletions)
@@ -54,20 +54,14 @@ def shannon_entropy(prob_dist, base: float = 2, backend=None):
             "All elements of the probability array must be between 0. and 1..",
         )
 
-    total_sum = (
-        backend.torch.sum(prob_dist) if backend.name == "pytorch" else np.sum(prob_dist)
-    )
+    total_sum = backend.np.sum(prob_dist)
 
     if np.abs(total_sum - 1.0) > PRECISION_TOL:
         raise_error(ValueError, "Probability array must sum to 1.")
 
     log_prob = np.where(prob_dist != 0, np.log2(prob_dist) / np.log2(base), 0.0)
 
-    shan_entropy = (
-        -backend.torch.sum(prob_dist * log_prob)
-        if backend.name == "pytorch"
-        else -np.sum(prob_dist * log_prob)
-    )
+    shan_entropy = -backend.np.sum(prob_dist * log_prob)
 
     # absolute value if entropy == 0.0 to avoid returning -0.0
     shan_entropy = np.abs(shan_entropy) if shan_entropy == 0.0 else shan_entropy
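For context, a hedged usage sketch of the function this hunk touches, assuming qibo's public NumpyBackend and the shannon_entropy signature shown in the hunk header; the same call now runs unchanged whichever backend is passed:

    import numpy as np
    from qibo.backends import NumpyBackend
    from qibo.quantum_info.entropies import shannon_entropy

    prob_dist = np.array([0.5, 0.25, 0.25])
    # -0.5*log2(0.5) - 2 * 0.25*log2(0.25) = 1.5 bits
    print(shannon_entropy(prob_dist, base=2, backend=NumpyBackend()))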
@@ -127,16 +121,10 @@ def classical_relative_entropy(prob_dist_p, prob_dist_q, base: float = 2, backend=None):
             ValueError,
             "All elements of the probability array must be between 0. and 1..",
         )
-    total_sum_p = (
-        backend.torch.sum(prob_dist_p)
-        if backend.name == "pytorch"
-        else np.sum(prob_dist_p)
-    )
-    total_sum_q = (
-        backend.torch.sum(prob_dist_q)
-        if backend.name == "pytorch"
-        else np.sum(prob_dist_q)
-    )
+    total_sum_p = backend.np.sum(prob_dist_p)
+
+    total_sum_q = backend.np.sum(prob_dist_q)
 
     if np.abs(total_sum_p - 1.0) > PRECISION_TOL:
         raise_error(ValueError, "First probability array must sum to 1.")

@@ -151,11 +139,7 @@
 
     log_prob = np.where(prob_dist_p != 0.0, log_prob_q, 0.0)
 
-    relative = (
-        backend.torch.sum(prob_dist_p * log_prob)
-        if backend.name == "pytorch"
-        else np.sum(prob_dist_p * log_prob)
-    )
+    relative = backend.np.sum(prob_dist_p * log_prob)
 
     return entropy_p - relative

@@ -228,9 +212,7 @@ def classical_renyi_entropy(
             "All elements of the probability array must be between 0. and 1..",
         )
 
-    total_sum = (
-        backend.torch.sum(prob_dist) if backend.name == "pytorch" else np.sum(prob_dist)
-    )
+    total_sum = backend.np.sum(prob_dist)
 
     if np.abs(total_sum - 1.0) > PRECISION_TOL:
         raise_error(ValueError, "Probability array must sum to 1.")
@@ -244,11 +226,7 @@
     if alpha == np.inf:
         return -1 * np.log2(max(prob_dist)) / np.log2(base)
 
-    total_sum = (
-        backend.torch.sum(prob_dist**alpha)
-        if backend.name == "pytorch"
-        else np.sum(prob_dist**alpha)
-    )
+    total_sum = backend.np.sum(prob_dist**alpha)
 
     renyi_ent = (1 / (1 - alpha)) * np.log2(total_sum) / np.log2(base)
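A quick numeric sanity check of the Renyi formula in this hunk, (1 / (1 - alpha)) * log2(total_sum) / log2(base), done in plain numpy with illustrative values:

    import numpy as np

    prob_dist = np.array([0.5, 0.25, 0.25])
    alpha, base = 2.0, 2
    total_sum = np.sum(prob_dist**alpha)  # 0.375
    renyi_ent = (1 / (1 - alpha)) * np.log2(total_sum) / np.log2(base)
    print(renyi_ent)  # ~1.415 bits, the collision entropy of this distribution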

@@ -332,16 +310,8 @@ def classical_relative_renyi_entropy(
             "All elements of the probability array must be between 0. and 1..",
         )
 
-    total_sum_p = (
-        backend.torch.sum(prob_dist_p)
-        if backend.name == "pytorch"
-        else np.sum(prob_dist_p)
-    )
-    total_sum_q = (
-        backend.torch.sum(prob_dist_q)
-        if backend.name == "pytorch"
-        else np.sum(prob_dist_q)
-    )
+    total_sum_p = backend.np.sum(prob_dist_p)
+    total_sum_q = backend.np.sum(prob_dist_q)
 
     if np.abs(total_sum_p - 1.0) > PRECISION_TOL:
         raise_error(ValueError, "First probability array must sum to 1.")
@@ -351,11 +321,8 @@
 
     if alpha == 0.5:
         total_sum = np.sqrt(prob_dist_p * prob_dist_q)
-        total_sum = (
-            backend.torch.sum(total_sum)
-            if backend.name == "pytorch"
-            else np.sum(total_sum)
-        )
+        total_sum = backend.np.sum(total_sum)
+
         return -2 * np.log2(total_sum) / np.log2(base)
 
     if alpha == 1.0:
@@ -369,11 +336,7 @@
     prob_p = prob_dist_p**alpha
     prob_q = prob_dist_q ** (1 - alpha)
 
-    total_sum = (
-        backend.torch.sum(prob_p * prob_q)
-        if backend.name == "pytorch"
-        else np.sum(prob_p * prob_q)
-    )
+    total_sum = backend.np.sum(prob_p * prob_q)
 
     return (1 / (alpha - 1)) * np.log2(total_sum) / np.log2(base)

@@ -431,9 +394,7 @@ def classical_tsallis_entropy(prob_dist, alpha: float, base: float = 2, backend=None):
             "All elements of the probability array must be between 0. and 1..",
         )
 
-    total_sum = (
-        backend.torch.sum(prob_dist) if backend.name == "pytorch" else np.sum(prob_dist)
-    )
+    total_sum = backend.np.sum(prob_dist)
 
     if np.abs(total_sum - 1.0) > PRECISION_TOL:
         raise_error(ValueError, "Probability array must sum to 1.")
@@ -442,9 +403,7 @@
         return shannon_entropy(prob_dist, base=base, backend=backend)
 
     total_sum = prob_dist**alpha
-    total_sum = (
-        backend.torch.sum(total_sum) if backend.name == "pytorch" else np.sum(total_sum)
-    )
+    total_sum = backend.np.sum(total_sum)
 
     return (1 / (1 - alpha)) * (total_sum - 1)
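The alpha == 1.0 branch above falls back to shannon_entropy because the Tsallis expression (total_sum - 1) / (1 - alpha) converges to -sum(p * ln(p)) as alpha approaches 1. A plain-numpy check of that limit, with illustrative values:

    import numpy as np

    prob_dist = np.array([0.5, 0.25, 0.25])
    alpha = 1.0 + 1e-6
    tsallis = (np.sum(prob_dist**alpha) - 1) / (1 - alpha)
    shannon_nats = -np.sum(prob_dist * np.log(prob_dist))  # Shannon entropy in nats
    print(tsallis, shannon_nats)  # agree to roughly six decimal places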

src/qibo/quantum_info/quantum_networks.py (4 changes: 1 addition & 3 deletions)
@@ -633,9 +633,7 @@ def _set_tensor_and_parameters(self):
         """Sets tensor based on inputs."""
         self._backend = _check_backend(self._backend)
 
-        self._einsum = (
-            self._backend.torch.einsum if self._backend.name == "pytorch" else np.einsum
-        )
+        self._einsum = self._backend.np.einsum
 
         if isinstance(self.partition, list):
             self.partition = tuple(self.partition)
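The single assignment works because numpy and torch both expose einsum with the same subscript syntax, so whatever self._backend.np points to satisfies the call. A minimal illustration in plain numpy:

    import numpy as np

    a = np.arange(6).reshape(2, 3)
    b = np.arange(12).reshape(3, 4)
    c = np.einsum("ij,jk->ik", a, b)  # torch.einsum accepts the same string
    print(c.shape)  # (2, 4)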
src/qibo/quantum_info/random_ensembles.py (6 changes: 2 additions & 4 deletions)
@@ -1178,14 +1178,12 @@ def _super_op_from_bcsz_measure(dims: int, rank: int, order: str, seed, backend):
     for eigenvalue, eigenvector in zip(eigenvalues, np.transpose(eigenvectors)):
         operator += eigenvalue * np.outer(eigenvector, np.conj(eigenvector))
 
-    kron = backend.torch.kron if backend.name == "pytorch" else np.kron
-
     if order == "row":
-        operator = kron(
+        operator = backend.np.kron(
             backend.identity_density_matrix(nqubits, normalize=False), operator
         )
     if order == "column":
-        operator = kron(
+        operator = backend.np.kron(
             operator, backend.identity_density_matrix(nqubits, normalize=False)
         )
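Same story for kron: both array libraries provide it under that name, so the local kron = ... alias is no longer needed. A sketch of the two orderings, with a stand-in for identity_density_matrix (illustrative values only):

    import numpy as np

    identity = np.eye(2)  # stand-in for the identity density matrix
    operator = np.array([[0.0, 1.0], [1.0, 0.0]])
    row_ordered = np.kron(identity, operator)     # order == "row"
    column_ordered = np.kron(operator, identity)  # order == "column"
    print(row_ordered.shape, column_ordered.shape)  # (4, 4) (4, 4)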

tests/test_quantum_info_random.py (33 changes: 13 additions & 20 deletions)
@@ -57,9 +57,7 @@ def test_uniform_sampling_U3(backend, seed):
     )
     expectation_values = backend.cast(expectation_values)
 
-    mean_function = backend.torch.mean if backend.name == "pytorch" else np.mean
-
-    expectation_values = mean_function(expectation_values, axis=0)
+    expectation_values = backend.np.mean(expectation_values, axis=0)
 
     backend.assert_allclose(expectation_values[0], expectation_values[1], atol=1e-1)
     backend.assert_allclose(expectation_values[0], expectation_values[2], atol=1e-1)
@@ -176,7 +174,7 @@ def test_random_unitary(backend, measure):
     matrix = random_unitary(dims, measure=measure, backend=backend)
     matrix_dagger = np.transpose(np.conj(matrix))
     matrix_inv = (
-        backend.torch.inverse(matrix)
+        backend.np.inverse(matrix)
         if backend.name == "pytorch"
         else np.linalg.inv(matrix)
     )
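This is the one conditional the commit keeps, since the spelling differs between libraries: torch has inverse, while numpy spells it np.linalg.inv. The property under test, sketched in plain numpy: a unitary's inverse equals its conjugate transpose.

    import numpy as np

    theta = 0.3
    matrix = np.array([[np.cos(theta), -np.sin(theta)],
                       [np.sin(theta), np.cos(theta)]])  # a simple 2x2 unitary
    matrix_dagger = np.transpose(np.conj(matrix))
    print(np.allclose(np.linalg.inv(matrix), matrix_dagger))  # True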
@@ -464,9 +462,8 @@ def test_random_pauli(
         )
     else:
         matrix = np.transpose(matrix, (1, 0, 2, 3))
-        kron = backend.torch.kron if backend.name == "pytorch" else np.kron
-        matrix = [reduce(kron, row) for row in matrix]
-        dot = backend.torch.matmul if backend.name == "pytorch" else np.dot
+        matrix = [reduce(backend.np.kron, row) for row in matrix]
+        dot = backend.np.matmul if backend.name == "pytorch" else np.dot
         matrix = reduce(dot, matrix)
 
     if subset is None:
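The dot line keeps its conditional as committed; note, though, that numpy also provides matmul, so the reduce pattern itself is backend-agnostic. An illustrative plain-numpy version of the tensor-product-then-multiply step:

    import numpy as np
    from functools import reduce

    paulis = [np.eye(2), np.array([[0, 1], [1, 0]])]  # I and X
    matrix = reduce(np.kron, paulis)                  # 4x4 tensor product I (x) X
    product = reduce(np.matmul, [matrix, matrix])     # (I (x) X)^2 == identity
    print(np.allclose(product, np.eye(4)))            # True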
@@ -558,13 +555,10 @@ def test_random_stochastic_matrix(backend):
         dims = 4
         random_stochastic_matrix(dims, seed=0.1, backend=backend)
 
-    sum_function = backend.torch.sum if backend.name == "pytorch" else np.sum
-    diag = backend.torch.diag if backend.name == "pytorch" else np.diag
-
     # tests if matrix is row-stochastic
     dims = 4
     matrix = random_stochastic_matrix(dims, backend=backend)
-    sum_rows = sum_function(matrix, axis=1)
+    sum_rows = backend.np.sum(matrix, axis=1)
 
     backend.assert_allclose(all(sum_rows < 1 + PRECISION_TOL), True)
     backend.assert_allclose(all(sum_rows > 1 - PRECISION_TOL), True)
@@ ... @@
     matrix = random_stochastic_matrix(
         dims, diagonally_dominant=True, max_iterations=1000, backend=backend
     )
 
-    sum_function = backend.torch.sum if backend.name == "pytorch" else np.sum
-    sum_rows = sum_function(matrix, axis=1)
+    sum_rows = backend.np.sum(matrix, axis=1)
 
     backend.assert_allclose(all(sum_rows < 1 + PRECISION_TOL), True)
     backend.assert_allclose(all(sum_rows > 1 - PRECISION_TOL), True)
 
-    backend.assert_allclose(all(2 * diag(matrix) - sum_rows > 0), True)
+    backend.assert_allclose(all(2 * backend.np.diag(matrix) - sum_rows > 0), True)
 
     # tests if matrix is bistochastic
     dims = 4
     matrix = random_stochastic_matrix(dims, bistochastic=True, backend=backend)
-    sum_rows = sum_function(matrix, axis=1)
-    column_rows = sum_function(matrix, axis=0)
+    sum_rows = backend.np.sum(matrix, axis=1)
+    column_rows = backend.np.sum(matrix, axis=0)
 
     backend.assert_allclose(all(sum_rows < 1 + PRECISION_TOL), True)
     backend.assert_allclose(all(sum_rows > 1 - PRECISION_TOL), True)
@@ ... @@
         max_iterations=1000,
         backend=backend,
     )
-    sum_rows = sum_function(matrix, axis=1)
-    column_rows = sum_function(matrix, axis=0)
+    sum_rows = backend.np.sum(matrix, axis=1)
+    column_rows = backend.np.sum(matrix, axis=0)
 
    backend.assert_allclose(all(sum_rows < 1 + PRECISION_TOL), True)
     backend.assert_allclose(all(sum_rows > 1 - PRECISION_TOL), True)
 
     backend.assert_allclose(all(column_rows < 1 + PRECISION_TOL), True)
     backend.assert_allclose(all(column_rows > 1 - PRECISION_TOL), True)
 
-    backend.assert_allclose(all(2 * diag(matrix) - sum_rows > 0), True)
-    backend.assert_allclose(all(2 * diag(matrix) - column_rows > 0), True)
+    backend.assert_allclose(all(2 * backend.np.diag(matrix) - sum_rows > 0), True)
+    backend.assert_allclose(all(2 * backend.np.diag(matrix) - column_rows > 0), True)
 
     # tests warning for max_iterations
     dims = 4
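The assertions in this test reduce to simple properties: a row-stochastic matrix has rows summing to 1, a bistochastic one also has columns summing to 1, and diagonal dominance here means each diagonal entry exceeds the sum of the other entries in its row (the 2 * diag(matrix) - sum_rows > 0 check). A plain-numpy sketch of the row-stochastic case:

    import numpy as np

    matrix = np.random.default_rng(0).random((4, 4))
    matrix /= matrix.sum(axis=1, keepdims=True)  # normalize each row
    sum_rows = np.sum(matrix, axis=1)
    print(np.allclose(sum_rows, 1.0))            # True: row-stochastic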
