Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
renatomello committed Sep 17, 2024
1 parent fd88c40 commit 4cf8598
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions tests/test_quantum_info_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
hellinger_fidelity,
hellinger_shot_error,
pqc_integral,
total_variation_distance,
)


Expand Down Expand Up @@ -212,6 +213,64 @@ def test_hellinger_shot_error(backend, validate, kind):
assert 2 * hellinger_error < hellinger_fid


@pytest.mark.parametrize("kind", [None, list])
@pytest.mark.parametrize("validate", [False, True])
def test_total_variation_distance(backend, validate, kind):
with pytest.raises(TypeError):
prob = np.random.rand(1, 2)
prob_q = np.random.rand(1, 5)
prob = backend.cast(prob, dtype=prob.dtype)
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)
test = total_variation_distance(prob, prob_q, backend=backend)
with pytest.raises(TypeError):
prob = np.random.rand(1, 2)[0]
prob_q = np.array([])
prob = backend.cast(prob, dtype=prob.dtype)
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)
test = total_variation_distance(prob, prob_q, backend=backend)
with pytest.raises(ValueError):
prob = np.array([-1, 2.0])
prob_q = np.random.rand(1, 5)[0]
prob = backend.cast(prob, dtype=prob.dtype)
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)
test = total_variation_distance(prob, prob_q, validate=True, backend=backend)
with pytest.raises(ValueError):
prob = np.random.rand(1, 2)[0]
prob_q = np.array([1.0, 0.0])
prob = backend.cast(prob, dtype=prob.dtype)
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)
test = total_variation_distance(prob, prob_q, validate=True, backend=backend)
with pytest.raises(ValueError):
prob = np.array([1.0, 0.0])
prob_q = np.random.rand(1, 2)[0]
prob = backend.cast(prob, dtype=prob.dtype)
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)
test = total_variation_distance(prob, prob_q, validate=True, backend=backend)

prob_p = np.random.rand(10)
prob_q = np.random.rand(10)
prob_p /= np.sum(prob_p)
prob_q /= np.sum(prob_q)
prob_p = backend.cast(prob_p, dtype=prob_p.dtype)
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)

target = float(backend.calculate_norm(prob_p - prob_q, order=1) / 2)

prob_p = (
kind(prob_p) if kind is not None else backend.cast(prob_p, dtype=prob_p.dtype)
)
prob_q = (
kind(prob_q) if kind is not None else backend.cast(prob_q, dtype=prob_q.dtype)
)

tvd = total_variation_distance(prob_p, prob_q, validate, backend)
distance = hellinger_distance(prob_p, prob_q, validate, backend)

assert tvd == target
assert tvd <= np.sqrt(2) * distance
assert tvd >= distance**2


def test_haar_integral_errors(backend):
with pytest.raises(TypeError):
nqubits, power_t, samples = 0.5, 2, 10
Expand Down

0 comments on commit 4cf8598

Please sign in to comment.