Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/pauli string python #7

Merged
merged 10 commits into from
Jul 18, 2024
164 changes: 164 additions & 0 deletions benchmarks/pauli_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import numpy as np
from dataclasses import dataclass


def pauli_matrices() -> dict:
s0 = np.array([[1, 0], [0, 1]], dtype=np.complex128)
s1 = np.array([[0, 1], [1, 0]], dtype=np.complex128)
s2 = np.array([[0, -1j], [1j, 0]], dtype=np.complex128)
s3 = np.array([[1, 0], [0, -1]], dtype=np.complex128)
return {"I": s0, "X": s1, "Y": s2, "Z": s3, 0: s0, 1: s1, 2: s2, 3: s3}


@dataclass
class PauliString:
string: str
weight: float = 1.0

def dense(self) -> np.ndarray:
paulis = pauli_matrices()
matrix = paulis[self.string[-1]]
for p in reversed(self.string[:-1]):
matrix = np.kron(paulis[p], matrix)
return self.weight * matrix
jamesETsmith marked this conversation as resolved.
Show resolved Hide resolved


# TODO more validation for the shape of inputs
@dataclass
class SparsePauliString:
columns: np.ndarray
values: np.ndarray
weight: float = 1.0
jamesETsmith marked this conversation as resolved.
Show resolved Hide resolved

def multiply(self, state: np.ndarray) -> np.ndarray:
if state.ndim == 1:
return self.weight * self.values * state[self.columns]
elif state.ndim == 2:
return self.weight * self.values[:, np.newaxis] * state[self.columns]
else:
raise ValueError("state must be a 1D or 2D array")

def dense(self) -> np.ndarray:
matrix = np.zeros((len(self.columns), len(self.columns)), dtype=np.complex128)
matrix[np.arange(len(self.columns)), self.columns] = self.weight * self.values
jamesETsmith marked this conversation as resolved.
Show resolved Hide resolved
return matrix


@dataclass
class SparseMatrix:
rows: np.ndarray
columns: np.ndarray
values: np.ndarray

jamesETsmith marked this conversation as resolved.
Show resolved Hide resolved

class PauliComposer:
def __init__(self, pauli: PauliString) -> None:
self.pauli = pauli
self.n_qubits = len(pauli.string)
self.n_vals = 1 << self.n_qubits
self.n_ys = pauli.string.count("Y")
jamesETsmith marked this conversation as resolved.
Show resolved Hide resolved

def __resolve_init_conditions(self) -> None:
first_col = 0
for p in self.pauli.string:
first_col <<= 1
if p == "X" or p == "Y":
first_col += 1

match self.n_ys % 4:
case 0:
first_val = 1.0
case 1:
first_val = -1.0j
case 2:
first_val = -1.0
case 3:
first_val = 1.0j

return first_col, first_val

def sparse_pauli(self) -> SparsePauliString:
cols = np.empty(self.n_vals, dtype=np.int32)
vals = np.empty(self.n_vals, dtype=np.complex128)
cols[0], vals[0] = self.__resolve_init_conditions()

for q in range(self.n_qubits):
p = self.pauli.string[self.n_qubits - q - 1]
pow_of_two = 1 << q

new_slice = slice(pow_of_two, 2 * pow_of_two)
old_slice = slice(0, pow_of_two)

match p:
case "I":
cols[new_slice] = cols[old_slice] + pow_of_two
vals[new_slice] = vals[old_slice]
case "X":
cols[new_slice] = cols[old_slice] - pow_of_two
vals[new_slice] = vals[old_slice]
case "Y":
cols[new_slice] = cols[old_slice] - pow_of_two
vals[new_slice] = -vals[old_slice]
case "Z":
cols[new_slice] = cols[old_slice] + pow_of_two
vals[new_slice] = -vals[old_slice]

return SparsePauliString(weight=self.pauli.weight, columns=cols, values=vals)

def sparse_diag_pauli(self) -> SparsePauliString:
assert self.pauli.string.count("X") + self.pauli.string.count("Y") == 0

cols = np.arange(self.n_vals, dtype=np.int32)
vals = np.ones(self.n_vals, dtype=np.complex128)

for q in range(self.n_qubits):
p = self.pauli.string[self.n_qubits - q - 1]
pow_of_two = 1 << q

new_slice = slice(pow_of_two, 2 * pow_of_two)
old_slice = slice(0, pow_of_two)

match p:
case "I":
vals[new_slice] = vals[old_slice]
case "Z":
vals[new_slice] = -vals[old_slice]

return SparsePauliString(weight=self.pauli.weight, columns=cols, values=vals)

def efficient_sparse_multiply(self, state: np.ndarray) -> np.ndarray:
assert state.ndim == 2

cols = np.empty(self.n_vals, dtype=np.int32)
vals = np.empty(self.n_vals, dtype=np.complex128)
cols[0], vals[0] = self.__resolve_init_conditions()

product = np.empty((self.n_vals, state.shape[1]), dtype=np.complex128)
product[0] = self.pauli.weight * vals[0] * state[cols[0]]

for q in range(self.n_qubits):
p = self.pauli.string[self.n_qubits - q - 1]
pow_of_two = 1 << q

new_slice = slice(pow_of_two, 2 * pow_of_two)
old_slice = slice(0, pow_of_two)

match p:
case "I":
cols[new_slice] = cols[old_slice] + pow_of_two
vals[new_slice] = vals[old_slice]
case "X":
cols[new_slice] = cols[old_slice] - pow_of_two
vals[new_slice] = vals[old_slice]
case "Y":
cols[new_slice] = cols[old_slice] - pow_of_two
vals[new_slice] = -vals[old_slice]
case "Z":
cols[new_slice] = cols[old_slice] + pow_of_two
vals[new_slice] = -vals[old_slice]

product[new_slice] = (
self.pauli.weight * vals[new_slice, np.newaxis] * state[cols[new_slice]]
)

return product
167 changes: 167 additions & 0 deletions benchmarks/test_pauli_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import pytest
import numpy as np
from itertools import permutations, chain

import pauli_operations
from pauli_operations import PauliString, SparsePauliString, PauliComposer


@pytest.fixture
def paulis():
return pauli_operations.pauli_matrices()


def test_pauli_strings(paulis):
for p in "IXYZ":
np.testing.assert_array_equal(PauliString(p).dense(), paulis[p])
jamesETsmith marked this conversation as resolved.
Show resolved Hide resolved

ps = PauliString("III", 0.5)
np.testing.assert_array_equal(ps.dense(), np.eye(8) * 0.5)

ps = PauliString(weight=1.0, string="IZ")
np.testing.assert_array_equal(ps.dense(), np.kron(paulis["I"], paulis["Z"]))

ps = PauliString(weight=0.5, string="XYZ")
np.testing.assert_array_equal(
ps.dense(), np.kron(paulis["X"], np.kron(paulis["Y"], paulis["Z"])) * 0.5
)

ps = SparsePauliString(np.arange(8), np.ones(8), 0.5)
np.testing.assert_array_equal(ps.dense(), np.eye(8) * 0.5)
m = np.array([[0, 1, 0], [0, 0, 2], [3, 0, 0]])
ps = SparsePauliString(columns=np.array([1, 2, 0]), values=np.array([1, 2, 3]))
np.testing.assert_array_equal(ps.dense(), m)


def test_pauli_composer(paulis):
for p in "IXYZ":
pc = PauliComposer(PauliString(p))
assert pc.n_qubits == 1
assert pc.n_vals == 2
assert pc.n_ys == 0 or p == "Y"
assert pc.n_ys == 1 or p != "Y"
np.testing.assert_array_equal(pc.sparse_pauli().dense(), paulis[p])

pc = PauliComposer(PauliString("II", weight=0.2))
assert pc.n_vals == 4
np.testing.assert_array_equal(pc.sparse_pauli().dense(), np.eye(4) * 0.2)
np.testing.assert_array_equal(pc.sparse_diag_pauli().dense(), np.eye(4) * 0.2)

pc = PauliComposer(PauliString("IIII"))
assert pc.n_vals == 16
np.testing.assert_array_equal(pc.sparse_pauli().dense(), np.eye(16))
np.testing.assert_array_equal(pc.sparse_diag_pauli().dense(), np.eye(16))

pc = PauliComposer(PauliString("II", weight=0.2))
assert pc.n_vals == 4
np.testing.assert_array_equal(pc.sparse_pauli().dense(), np.eye(4) * 0.2)

pc = PauliComposer(PauliString("XXX", weight=1.0))
assert pc.n_vals == 8
np.testing.assert_array_equal(pc.sparse_pauli().dense(), np.fliplr(np.eye(8)))

pc = PauliComposer(PauliString("IY"))
np.testing.assert_array_equal(
pc.sparse_pauli().dense(),
np.block([[paulis["Y"], np.zeros((2, 2))], [np.zeros((2, 2)), paulis["Y"]]]),
)

pc = PauliComposer(PauliString("IZ"))
np.testing.assert_array_equal(
pc.sparse_pauli().dense(),
np.block([[paulis["Z"], np.zeros((2, 2))], [np.zeros((2, 2)), paulis["Z"]]]),
)


def test_pauli_composer_equivalence():
rng = np.random.default_rng(321)

for c in "IXYZ":
w = rng.random()
np.testing.assert_array_equal(
PauliComposer(PauliString(c, w)).sparse_pauli().dense(),
PauliString(c, w).dense(),
)

for s in permutations("XYZ", 2):
s = "".join(s)
w = rng.random()
np.testing.assert_array_equal(
PauliComposer(PauliString(s, w)).sparse_pauli().dense(),
PauliString(s, w).dense(),
)

for s in permutations("IXYZ", 3):
s = "".join(s)
w = rng.random()
np.testing.assert_array_equal(
PauliComposer(PauliString(s, w)).sparse_pauli().dense(),
PauliString(s, w).dense(),
)

ixyz = PauliComposer(PauliString("IXYZ")).sparse_pauli().dense()
np.testing.assert_array_equal(ixyz, PauliString("IXYZ").dense())

zyxi = PauliComposer(PauliString("ZYXI")).sparse_pauli().dense()
np.testing.assert_array_equal(zyxi, PauliString("ZYXI").dense())

assert np.abs(ixyz - zyxi).sum().sum() > 1e-10

for s in ["XYIZXYZ", "XXIYYIZZ", "ZIXIZYXX"]:
np.testing.assert_array_equal(
PauliComposer(PauliString(s)).sparse_pauli().dense(), PauliString(s).dense()
)


def test_sparse_pauli_multiply():
rng = np.random.default_rng(321)

for s in chain(
list("IXYZ"), list(permutations("IXYZ", 3)), ["XYIZXYZ", "XXIYYIZZ", "ZIXIZYXX"]
):
s = "".join(s)
n = 2 ** len(s)
w = rng.random()
psi = rng.random(n)
psi_batch = rng.random((n, 25))

np.testing.assert_allclose(
PauliComposer(PauliString(s, w)).sparse_pauli().multiply(psi),
PauliString(s, w).dense().dot(psi),
atol=1e-15,
)
np.testing.assert_allclose(
PauliComposer(PauliString(s, w)).sparse_pauli().multiply(psi_batch),
PauliString(s, w).dense() @ psi_batch,
atol=1e-15,
)


def test_pauli_composer_multiply():
rng = np.random.default_rng(321)

for s in chain(
list("IXYZ"), list(permutations("IXYZ", 3)), ["XYIZXYZ", "XXIYYIZZ", "ZIXIZYXX"]
):
s = "".join(s)
n = 2 ** len(s)
w = rng.random()
psi = rng.random(n)
psi_batch = rng.random((n, 20))

np.testing.assert_allclose(
PauliComposer(PauliString(s, w))
.efficient_sparse_multiply(psi.reshape(-1, 1))
.ravel(),
jamesETsmith marked this conversation as resolved.
Show resolved Hide resolved
PauliString(s, w).dense().dot(psi),
atol=1e-15,
)
np.testing.assert_allclose(
PauliComposer(PauliString(s, w)).efficient_sparse_multiply(psi_batch),
PauliString(s, w).dense() @ psi_batch,
atol=1e-15,
)


if __name__ == "__main__":
pytest.main([__file__])
4 changes: 4 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
numpy
scipy
numba
pytest
jamesETsmith marked this conversation as resolved.
Show resolved Hide resolved
Loading