From 4cf85983bba90b508f100e355d29f37dc08660a5 Mon Sep 17 00:00:00 2001 From: Renato Mello Date: Tue, 17 Sep 2024 15:59:38 +0400 Subject: [PATCH] test --- tests/test_quantum_info_utils.py | 59 ++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/test_quantum_info_utils.py b/tests/test_quantum_info_utils.py index 8a375e0f21..238241348e 100644 --- a/tests/test_quantum_info_utils.py +++ b/tests/test_quantum_info_utils.py @@ -17,6 +17,7 @@ hellinger_fidelity, hellinger_shot_error, pqc_integral, + total_variation_distance, ) @@ -212,6 +213,64 @@ def test_hellinger_shot_error(backend, validate, kind): assert 2 * hellinger_error < hellinger_fid +@pytest.mark.parametrize("kind", [None, list]) +@pytest.mark.parametrize("validate", [False, True]) +def test_total_variation_distance(backend, validate, kind): + with pytest.raises(TypeError): + prob = np.random.rand(1, 2) + prob_q = np.random.rand(1, 5) + prob = backend.cast(prob, dtype=prob.dtype) + prob_q = backend.cast(prob_q, dtype=prob_q.dtype) + test = total_variation_distance(prob, prob_q, backend=backend) + with pytest.raises(TypeError): + prob = np.random.rand(1, 2)[0] + prob_q = np.array([]) + prob = backend.cast(prob, dtype=prob.dtype) + prob_q = backend.cast(prob_q, dtype=prob_q.dtype) + test = total_variation_distance(prob, prob_q, backend=backend) + with pytest.raises(ValueError): + prob = np.array([-1, 2.0]) + prob_q = np.random.rand(1, 5)[0] + prob = backend.cast(prob, dtype=prob.dtype) + prob_q = backend.cast(prob_q, dtype=prob_q.dtype) + test = total_variation_distance(prob, prob_q, validate=True, backend=backend) + with pytest.raises(ValueError): + prob = np.random.rand(1, 2)[0] + prob_q = np.array([1.0, 0.0]) + prob = backend.cast(prob, dtype=prob.dtype) + prob_q = backend.cast(prob_q, dtype=prob_q.dtype) + test = total_variation_distance(prob, prob_q, validate=True, backend=backend) + with pytest.raises(ValueError): + prob = np.array([1.0, 0.0]) + prob_q = np.random.rand(1, 2)[0] + prob = backend.cast(prob, dtype=prob.dtype) + prob_q = backend.cast(prob_q, dtype=prob_q.dtype) + test = total_variation_distance(prob, prob_q, validate=True, backend=backend) + + prob_p = np.random.rand(10) + prob_q = np.random.rand(10) + prob_p /= np.sum(prob_p) + prob_q /= np.sum(prob_q) + prob_p = backend.cast(prob_p, dtype=prob_p.dtype) + prob_q = backend.cast(prob_q, dtype=prob_q.dtype) + + target = float(backend.calculate_norm(prob_p - prob_q, order=1) / 2) + + prob_p = ( + kind(prob_p) if kind is not None else backend.cast(prob_p, dtype=prob_p.dtype) + ) + prob_q = ( + kind(prob_q) if kind is not None else backend.cast(prob_q, dtype=prob_q.dtype) + ) + + tvd = total_variation_distance(prob_p, prob_q, validate, backend) + distance = hellinger_distance(prob_p, prob_q, validate, backend) + + assert tvd == target + assert tvd <= np.sqrt(2) * distance + assert tvd >= distance**2 + + def test_haar_integral_errors(backend): with pytest.raises(TypeError): nqubits, power_t, samples = 0.5, 2, 10