Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
stand-by committed Aug 7, 2024
1 parent f007b1d commit 4bff8d8
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 20 deletions.
9 changes: 4 additions & 5 deletions fast_pauli/cpp/include/__pauli_string.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ struct PauliString {
* This function takes in transposed states with (n_dims x n_states) shape
*
* It computes following inner product
* \f$ \bra{\psi_t} \mathcal{x_{ti}\hat{P_i}} \ket{\psi_t} \f$
* \f$ \bra{\psi_t} \mathcal{\hat{P_i}} \ket{\psi_t} \f$
* for each state \f$ \ket{\psi_t} \f$ from provided batch.
*
* @tparam T The floating point base to use for all the complex numbers
Expand All @@ -315,8 +315,8 @@ struct PauliString {
*/
template <std::floating_point T>
std::vector<std::complex<T>> expected_value(
std::mdspan<std::complex<T> const, std::dextents<size_t, 2>> states_T,
std::complex<T> const c) const {
std::mdspan<std::complex<T> const, std::dextents<size_t, 2>> states_T)
const {
// Input check
if (states_T.extent(0) != dims())
throw std::invalid_argument(
Expand All @@ -330,9 +330,8 @@ struct PauliString {

std::vector<std::complex<T>> exp_val(states_T.extent(1), 0);
for (size_t i = 0; i < states_T.extent(0); ++i) {
const auto c_m_i = c * m[i];
for (size_t t = 0; t < states_T.extent(1); ++t) {
exp_val[t] += std::conj(states_T(i, t)) * c_m_i * states_T(k[i], t);
exp_val[t] += std::conj(states_T(i, t)) * m[i] * states_T(k[i], t);
}
}

Expand Down
23 changes: 15 additions & 8 deletions fast_pauli/cpp/src/fast_pauli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ namespace py = pybind11;
using namespace pybind11::literals;

namespace fast_pauli {

/**
* @brief Flatten row major matrix represented as a 2D-std::vector into single
* std::vector
*
* @tparam T The floating point base to use for all the complex numbers
* @return std::vector<std::complex<T>> flattened vector with rows concatenated
*/
template <std::floating_point T>
inline std::vector<std::complex<T>>
flatten_vector(std::vector<std::vector<std::complex<T>>> const &inputs) {
Expand All @@ -26,6 +34,7 @@ flatten_vector(std::vector<std::vector<std::complex<T>>> const &inputs) {

return flat;
}

} // namespace fast_pauli

PYBIND11_MODULE(_fast_pauli, m) {
Expand Down Expand Up @@ -95,27 +104,25 @@ PYBIND11_MODULE(_fast_pauli, m) {
.def(
"expected_value",
[](fp::PauliString const &self,
std::vector<std::complex<double>> state,
std::complex<double> coeff) {
std::vector<std::complex<double>> state) {
std::mdspan<std::complex<double> const, std::dextents<size_t, 2>>
span_state{state.data(), state.size(), 1};
return self.expected_value(span_state, coeff);
return self.expected_value(span_state);
},
"state"_a, "coeff"_a = std::complex<double>{1.0})
"state"_a)
.def(
"expected_value",
[](fp::PauliString const &self,
std::vector<std::vector<std::complex<double>>> states,
std::complex<double> coeff) {
std::vector<std::vector<std::complex<double>>> states) {
if (states.empty())
return std::vector<std::complex<double>>{};
auto flat_states = fp::flatten_vector(states);
std::mdspan<std::complex<double> const, std::dextents<size_t, 2>>
span_states{flat_states.data(), states.size(),
states.front().size()};
return self.expected_value(span_states, coeff);
return self.expected_value(span_states);
},
"states"_a, "coeff"_a = std::complex<double>{1.0})
"states"_a)
.def("__str__",
[](fp::PauliString const &self) { return fmt::format("{}", self); });

Expand Down
6 changes: 2 additions & 4 deletions fast_pauli/py/pypauli/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ def multiply(self, states: np.ndarray, coeff: np.complex128 = 1.0) -> np.ndarray
else:
return values * states[columns]

def expected_value(
self, state: np.ndarray, coeff: np.complex128 = 1.0
) -> np.complex128 | np.ndarray:
def expected_value(self, state: np.ndarray) -> np.complex128 | np.ndarray:
"""Compute the expected value of Pauli string for a given state.
Args:
Expand All @@ -87,7 +85,7 @@ def expected_value(
The expected value of the Pauli string with the state.
"""
return np.multiply(state.conj(), self.multiply(state, coeff)).sum(axis=0)
return np.multiply(state.conj(), self.multiply(state)).sum(axis=0)


def compose_sparse_pauli(string: str) -> tuple[np.ndarray, np.ndarray]:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,11 @@ def test_expected_value(
atol=1e-15,
)

coeff = generate_random_complex(1)[0]
psis = generate_random_complex(n_states, n_dim)
expected = np.einsum("ti,ij,tj->t", psi.conj(), naive_pauil_converter(s), psi)
# compute <psi_t|P_i|psi_t>
expected = np.einsum("ti,ij,tj->t", psis.conj(), naive_pauli_converter(s), psis)
np.testing.assert_allclose(
PauliString(s).expected_value(psis, coeff),
PauliString(s).expected_value(psis.T),
expected,
atol=1e-15,
)
Expand Down

0 comments on commit 4bff8d8

Please sign in to comment.