Skip to content

Commit

Permalink
Combine unit tests and consistency checks for both python and c++ Pau…
Browse files Browse the repository at this point in the history
…liString classes
  • Loading branch information
stand-by committed Aug 10, 2024
1 parent 238f3de commit e3f5893
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 88 deletions.
13 changes: 9 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
"""pytest configuration file for fast_pauli tests."""

import itertools as it
from typing import Callable

import numpy as np
import pytest

from fast_pauli.pypauli.helpers import pauli_matrices

# TODO: fixtures to wrap around numpy testing functions with default tolerances


@pytest.fixture
def paulis() -> dict[str | int, np.ndarray]:
Expand All @@ -17,15 +16,21 @@ def paulis() -> dict[str | int, np.ndarray]:


@pytest.fixture
def sample_pauli_strings(limit_strings: int = 1_000) -> list[str]:
def sample_pauli_strings() -> list[str]:
"""Fixture to provide sample Pauli strings for testing."""
strings = it.chain(
["I", "X", "Y", "Z"],
it.product("IXYZ", repeat=2),
it.product("IXYZ", repeat=3),
["XYZXYZ", "ZZZIII", "XYIZXYZ", "XXIYYIZZ", "ZIXIZYXX"],
)
return list(map("".join, strings))[:limit_strings]
return list(map("".join, strings))


@pytest.fixture
def pauli_strings_with_size() -> Callable:
"""Fixture to provide Pauli strings of desired size for testing."""
return lambda size: list(map(lambda x: "".join(x), it.product("IXYZ", repeat=size)))


@pytest.fixture(scope="function")
Expand Down
6 changes: 3 additions & 3 deletions tests/test_pauli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@pytest.mark.parametrize("Pauli,", [(fp.Pauli)], ids=resolve_parameter_repr)
def test_pauli_wrapper(paulis: dict, Pauli: type) -> None: # noqa: N803
def test_basics(paulis: dict, Pauli: type) -> None: # noqa: N803
"""Test pauli wrapper in python land."""
np.testing.assert_array_equal(
Pauli().to_tensor(),
Expand All @@ -35,7 +35,7 @@ def test_pauli_wrapper(paulis: dict, Pauli: type) -> None: # noqa: N803


@pytest.mark.parametrize("Pauli,", [(fp.Pauli)], ids=resolve_parameter_repr)
def test_pauli_wrapper_multiply(paulis: dict, Pauli: type) -> None: # noqa: N803
def test_multiply(paulis: dict, Pauli: type) -> None: # noqa: N803
"""Test custom __mul__ in c++ wrapper."""
for p1, p2 in it.product("IXYZ", repeat=2):
c, pcpp = Pauli(p1).multiply(Pauli(p2))
Expand All @@ -47,7 +47,7 @@ def test_pauli_wrapper_multiply(paulis: dict, Pauli: type) -> None: # noqa: N80


@pytest.mark.parametrize("Pauli,", [(fp.Pauli)], ids=resolve_parameter_repr)
def test_pauli_wrapper_exceptions(Pauli: type) -> None: # noqa: N803
def test_exceptions(Pauli: type) -> None: # noqa: N803
"""Test that exceptions from c++ are raised and propagated correctly."""
with np.testing.assert_raises(ValueError):
Pauli("II")
Expand Down
Loading

0 comments on commit e3f5893

Please sign in to comment.