Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stim.FlipSimulator #612

Merged
merged 9 commits into from
Aug 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
624 changes: 624 additions & 0 deletions doc/python_api_reference_vDev.md

Large diffs are not rendered by default.

520 changes: 520 additions & 0 deletions doc/stim.pyi

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions file_lists/python_api_files
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ src/stim/py/march.pybind.cc
src/stim/py/numpy.pybind.cc
src/stim/py/stim.pybind.cc
src/stim/simulators/dem_sampler.pybind.cc
src/stim/simulators/frame_simulator.pybind.cc
src/stim/simulators/matched_error.pybind.cc
src/stim/simulators/measurements_to_detection_events.pybind.cc
src/stim/simulators/tableau_simulator.pybind.cc
Expand Down
520 changes: 520 additions & 0 deletions glue/python/src/stim/__init__.pyi

Large diffs are not rendered by default.

52 changes: 1 addition & 51 deletions src/stim/circuit/circuit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -706,59 +706,9 @@ size_t Circuit::count_sweep_bits() const {

CircuitStats Circuit::compute_stats() const {
CircuitStats total;

for (const auto &op : operations) {
if (op.gate_type == REPEAT) {
// Recurse into blocks.
auto sub = op.repeat_block_body(*this).compute_stats();
auto reps = op.repeat_block_rep_count();
total.num_observables = std::max(total.num_observables, sub.num_observables);
total.num_qubits = std::max(total.num_qubits, sub.num_qubits);
total.max_lookback = std::max(total.max_lookback, sub.max_lookback);
total.num_sweep_bits = std::max(total.num_sweep_bits, sub.num_sweep_bits);
total.num_detectors = add_saturate(total.num_detectors, mul_saturate(sub.num_detectors, reps));
total.num_measurements = add_saturate(total.num_measurements, mul_saturate(sub.num_measurements, reps));
total.num_ticks = add_saturate(total.num_ticks, mul_saturate(sub.num_ticks, reps));
continue;
}

for (auto t : op.targets) {
auto v = t.data & TARGET_VALUE_MASK;
// Qubit counting.
if (!(t.data & (TARGET_RECORD_BIT | TARGET_SWEEP_BIT))) {
total.num_qubits = std::max(total.num_qubits, v + 1);
}
// Lookback counting.
if (t.data & TARGET_RECORD_BIT) {
total.max_lookback = std::max(total.max_lookback, v);
}
// Sweep bit counting.
if (t.data & TARGET_SWEEP_BIT) {
total.num_sweep_bits = std::max(total.num_sweep_bits, v + 1);
}
}

// Measurement counting.
total.num_measurements += op.count_measurement_results();

switch (op.gate_type) {
case GateType::DETECTOR:
// Detector counting.
total.num_detectors += total.num_detectors < UINT64_MAX;
break;
case GateType::OBSERVABLE_INCLUDE:
// Observable counting.
total.num_observables = std::max(total.num_observables, (uint64_t)op.args[0] + 1);
break;
case GateType::TICK:
// Tick counting.
total.num_ticks++;
break;
default:
break;
}
op.add_stats_to(total, this);
}

return total;
}

Expand Down
11 changes: 0 additions & 11 deletions src/stim/circuit/circuit.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,6 @@ namespace stim {
uint64_t add_saturate(uint64_t a, uint64_t b);
uint64_t mul_saturate(uint64_t a, uint64_t b);

/// Stores a variety of circuit quantities relevant for sizing memory.
struct CircuitStats {
uint64_t num_detectors = 0;
uint64_t num_observables = 0;
uint64_t num_measurements = 0;
uint32_t num_qubits = 0;
uint32_t num_ticks = 0;
uint32_t max_lookback = 0;
uint32_t num_sweep_bits = 0;
};

/// A description of a quantum computation.
struct Circuit {
/// Backing data stores for variable-sized target data referenced by operations.
Expand Down
61 changes: 61 additions & 0 deletions src/stim/circuit/circuit_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,67 @@ Circuit &CircuitInstruction::repeat_block_body(Circuit &host) const {
return host.blocks[b];
}

CircuitStats CircuitInstruction::compute_stats(const Circuit *host) const {
CircuitStats out;
add_stats_to(out, host);
return out;
}

void CircuitInstruction::add_stats_to(CircuitStats &out, const Circuit *host) const {
if (gate_type == REPEAT) {
if (host == nullptr) {
throw std::invalid_argument("gate_type == REPEAT && host == nullptr");
}
// Recurse into blocks.
auto sub = repeat_block_body(*host).compute_stats();
auto reps = repeat_block_rep_count();
out.num_observables = std::max(out.num_observables, sub.num_observables);
out.num_qubits = std::max(out.num_qubits, sub.num_qubits);
out.max_lookback = std::max(out.max_lookback, sub.max_lookback);
out.num_sweep_bits = std::max(out.num_sweep_bits, sub.num_sweep_bits);
out.num_detectors = add_saturate(out.num_detectors, mul_saturate(sub.num_detectors, reps));
out.num_measurements = add_saturate(out.num_measurements, mul_saturate(sub.num_measurements, reps));
out.num_ticks = add_saturate(out.num_ticks, mul_saturate(sub.num_ticks, reps));
return;
}

for (auto t : targets) {
auto v = t.data & TARGET_VALUE_MASK;
// Qubit counting.
if (!(t.data & (TARGET_RECORD_BIT | TARGET_SWEEP_BIT))) {
out.num_qubits = std::max(out.num_qubits, v + 1);
}
// Lookback counting.
if (t.data & TARGET_RECORD_BIT) {
out.max_lookback = std::max(out.max_lookback, v);
}
// Sweep bit counting.
if (t.data & TARGET_SWEEP_BIT) {
out.num_sweep_bits = std::max(out.num_sweep_bits, v + 1);
}
}

// Measurement counting.
out.num_measurements += count_measurement_results();

switch (gate_type) {
case GateType::DETECTOR:
// Detector counting.
out.num_detectors += out.num_detectors < UINT64_MAX;
break;
case GateType::OBSERVABLE_INCLUDE:
// Observable counting.
out.num_observables = std::max(out.num_observables, (uint64_t)args[0] + 1);
break;
case GateType::TICK:
// Tick counting.
out.num_ticks++;
break;
default:
break;
}
}

const Circuit &CircuitInstruction::repeat_block_body(const Circuit &host) const {
assert(targets.size() == 3);
auto b = targets[0].data;
Expand Down
28 changes: 28 additions & 0 deletions src/stim/circuit/circuit_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,29 @@ namespace stim {

struct Circuit;

/// Stores a variety of circuit quantities relevant for sizing memory.
struct CircuitStats {
uint64_t num_detectors = 0;
uint64_t num_observables = 0;
uint64_t num_measurements = 0;
uint32_t num_qubits = 0;
uint64_t num_ticks = 0;
uint32_t max_lookback = 0;
uint32_t num_sweep_bits = 0;

inline CircuitStats repeated(uint64_t repetitions) const {
return CircuitStats{
num_detectors * repetitions,
num_observables,
num_measurements * repetitions,
num_qubits,
(uint32_t)(num_ticks * repetitions),
max_lookback,
num_sweep_bits,
};
}
};

/// The data that describes how a gate is being applied to qubits (or other targets).
///
/// A gate applied to targets.
Expand All @@ -49,6 +72,11 @@ struct CircuitInstruction {
CircuitInstruction() = delete;
CircuitInstruction(GateType gate_type, SpanRef<const double> args, SpanRef<const GateTarget> targets);

/// Computes number of qubits, number of measurements, etc.
CircuitStats compute_stats(const Circuit *host) const;
/// Computes number of qubits, number of measurements, etc and adds them into a target.
void add_stats_to(CircuitStats &out, const Circuit *host) const;

/// Determines if two operations can be combined into one operation (with combined targeting data).
///
/// For example, `H 1` then `H 2 1` is equivalent to `H 1 2 1` so those instructions are fusable.
Expand Down
10 changes: 10 additions & 0 deletions src/stim/mem/simd_bit_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ struct simd_bit_table {
/// Resizes the table. Doesn't clear to zero. Does nothing if already the target size.
void destructive_resize(size_t new_min_bits_major, size_t new_min_bits_minor);

/// Copies the table into another table.
///
/// It's safe for the other table to have a different size.
/// When the other table has a different size, only the data at locations common to both
/// tables are copied over.
void copy_into_different_size_table(simd_bit_table<W> &other) const;

/// Resizes the table, keeping any data common to the old and new size and otherwise zeroing data.
void resize(size_t new_min_bits_major, size_t new_min_bits_minor);

/// Equality.
bool operator==(const simd_bit_table &other) const;
/// Inequality.
Expand Down
42 changes: 42 additions & 0 deletions src/stim/mem/simd_bit_table.inl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,36 @@ void simd_bit_table<W>::destructive_resize(size_t new_min_bits_major, size_t new
data.destructive_resize(num_simd_words_minor * num_simd_words_major * W * W);
}

template <size_t W>
void simd_bit_table<W>::copy_into_different_size_table(simd_bit_table<W> &other) const {
size_t ni = num_simd_words_minor;
size_t na = num_simd_words_major;
size_t mi = other.num_simd_words_minor;
size_t ma = other.num_simd_words_major;
size_t num_min_bytes = std::min(ni, mi) * (W / 8);
size_t num_maj = std::min(na, ma) * W;

if (ni == mi) {
memcpy(other.data.ptr_simd, data.ptr_simd, num_min_bytes * num_maj);
} else {
for (size_t maj = 0; maj < num_maj; maj++) {
memcpy(other[maj].ptr_simd, (*this)[maj].ptr_simd, num_min_bytes);
}
}
}

template <size_t W>
void simd_bit_table<W>::resize(size_t new_min_bits_major, size_t new_min_bits_minor) {
auto new_num_simd_words_minor = min_bits_to_num_simd_words<W>(new_min_bits_minor);
auto new_num_simd_words_major = min_bits_to_num_simd_words<W>(new_min_bits_major);
if (new_num_simd_words_major == num_simd_words_major && new_num_simd_words_minor == num_simd_words_minor) {
return;
}
auto new_table = simd_bit_table<W>(new_min_bits_major, new_min_bits_minor);
copy_into_different_size_table(new_table);
*this = std::move(new_table);
}

template <size_t W>
void simd_bit_table<W>::do_square_transpose() {
assert(num_simd_words_minor == num_simd_words_major);
Expand Down Expand Up @@ -138,6 +168,18 @@ simd_bit_table<W> simd_bit_table<W>::transposed() const {
return result;
}

template <size_t W>
simd_bits<W> simd_bit_table<W>::read_across_majors_at_minor_index(size_t major_start, size_t major_stop, size_t minor_index) const {
assert(major_stop >= major_start);
assert(major_stop <= num_major_bits_padded());
assert(minor_index < num_minor_bits_padded());
simd_bits<W> result(major_stop - major_start);
for (size_t maj = major_start; maj < major_stop; maj++) {
result[maj - major_start] = (*this)[maj][minor_index];
}
return result;
}

template <size_t W>
simd_bit_table<W> simd_bit_table<W>::slice_maj(size_t maj_start_bit, size_t maj_stop_bit) const {
simd_bit_table<W> result(maj_stop_bit - maj_start_bit, num_minor_bits_padded());
Expand Down
106 changes: 105 additions & 1 deletion src/stim/mem/simd_bit_table.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ TEST(simd_bit_table, lg) {

TEST_EACH_WORD_SIZE_W(simd_bit_table, destructive_resize, {
auto rng = INDEPENDENT_TEST_RNG();
simd_bit_table<W> table = table.random(5, 7, rng);
simd_bit_table<W> table = simd_bit_table<W>::random(5, 7, rng);
const uint8_t *prev_pointer = table.data.u8;
table.destructive_resize(5, 7);
ASSERT_EQ(table.data.u8, prev_pointer);
Expand All @@ -302,3 +302,107 @@ TEST_EACH_WORD_SIZE_W(simd_bit_table, destructive_resize, {
ASSERT_GE(table.num_major_bits_padded(), 1025);
ASSERT_GE(table.num_minor_bits_padded(), 7);
})

TEST_EACH_WORD_SIZE_W(simd_bit_table, read_across_majors_at_minor_index, {
auto rng = INDEPENDENT_TEST_RNG();
simd_bit_table<W> table = simd_bit_table<W>::random(5, 7, rng);
simd_bits<W> slice = table.read_across_majors_at_minor_index(2, 5, 1);
ASSERT_GE(slice.num_bits_padded(), 4);
ASSERT_EQ(slice[0], table[2][1]);
ASSERT_EQ(slice[1], table[3][1]);
ASSERT_EQ(slice[2], table[4][1]);
ASSERT_EQ(slice[3], false);
})

template <size_t W>
bool is_table_overlap_identical(const simd_bit_table<W> &a, const simd_bit_table<W> &b) {
size_t w_min = std::min(a.num_simd_words_minor, b.num_simd_words_minor);
size_t n_maj = std::min(a.num_major_bits_padded(), b.num_major_bits_padded());
for (size_t k_maj = 0; k_maj < n_maj; k_maj++) {
if (a[k_maj].word_range_ref(0, w_min) != b[k_maj].word_range_ref(0, w_min)) {
return false;
}
}
return true;
}

template <size_t W>
bool is_table_zero_outside(const simd_bit_table<W> &a, size_t num_major_bits, size_t num_minor_bits) {
size_t num_major_words = min_bits_to_num_simd_words<W>(num_major_bits);
size_t num_minor_words = min_bits_to_num_simd_words<W>(num_minor_bits);
if (a.num_simd_words_minor > num_minor_words) {
for (size_t k = 0; k < a.num_simd_words_major; k++) {
if (a[k].word_range_ref(num_minor_words, a.num_simd_words_minor - num_minor_words).not_zero()) {
return false;
}
}
}
for (size_t k = a.num_simd_words_major; k < num_major_words; k++) {
if (a[k].not_zero()) {
return false;
}
}
return true;
}

TEST_EACH_WORD_SIZE_W(simd_bit_table, copy_into_different_size_table, {
auto rng = INDEPENDENT_TEST_RNG();

auto check_size = [&](size_t w1, size_t h1, size_t w2, size_t h2) {
simd_bit_table<W> src = simd_bit_table<W>::random(w1, h1, rng);
simd_bit_table<W> dst = simd_bit_table<W>::random(w1, h1, rng);
src.copy_into_different_size_table(dst);
return is_table_overlap_identical(src, dst);
};

EXPECT_TRUE(check_size(0, 0, 0, 0));

EXPECT_TRUE(check_size(64, 0, 0, 0));
EXPECT_TRUE(check_size(0, 64, 0, 0));
EXPECT_TRUE(check_size(0, 0, 64, 0));
EXPECT_TRUE(check_size(0, 0, 0, 64));

EXPECT_TRUE(check_size(64, 64, 64, 64));
EXPECT_TRUE(check_size(512, 64, 64, 64));
EXPECT_TRUE(check_size(64, 512, 64, 64));
EXPECT_TRUE(check_size(64, 64, 512, 64));
EXPECT_TRUE(check_size(64, 64, 64, 512));

EXPECT_TRUE(check_size(512, 512, 64, 64));
EXPECT_TRUE(check_size(512, 64, 512, 64));
EXPECT_TRUE(check_size(512, 64, 64, 512));
EXPECT_TRUE(check_size(64, 512, 512, 64));
EXPECT_TRUE(check_size(64, 512, 64, 512));
EXPECT_TRUE(check_size(64, 64, 512, 512));
})

TEST_EACH_WORD_SIZE_W(simd_bit_table, resize, {
auto rng = INDEPENDENT_TEST_RNG();

auto check_size = [&](size_t w1, size_t h1, size_t w2, size_t h2) {
simd_bit_table<W> src = simd_bit_table<W>::random(w1, h1, rng);
simd_bit_table<W> dst = src;
dst.resize(w2, h2);
return is_table_overlap_identical(src, dst) && is_table_zero_outside(dst, std::min(w1, w2), std::min(h1, h2));
};

EXPECT_TRUE(check_size(0, 0, 0, 0));

EXPECT_TRUE(check_size(64, 0, 0, 0));
EXPECT_TRUE(check_size(0, 64, 0, 0));
EXPECT_TRUE(check_size(0, 0, 64, 0));
EXPECT_TRUE(check_size(0, 0, 0, 64));

EXPECT_TRUE(check_size(64, 64, 64, 64));
EXPECT_TRUE(check_size(512, 64, 64, 64));
EXPECT_TRUE(check_size(64, 512, 64, 64));
EXPECT_TRUE(check_size(64, 64, 512, 64));
EXPECT_TRUE(check_size(64, 64, 64, 512));

EXPECT_TRUE(check_size(512, 512, 64, 64));
EXPECT_TRUE(check_size(512, 64, 512, 64));
EXPECT_TRUE(check_size(512, 64, 64, 512));
EXPECT_TRUE(check_size(64, 512, 512, 64));
EXPECT_TRUE(check_size(64, 512, 64, 512));
EXPECT_TRUE(check_size(64, 64, 512, 512));
})
3 changes: 2 additions & 1 deletion src/stim/py/compiled_detector_sampler.pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ pybind11::object CompiledDetectorSampler::sample_to_numpy(
}

frame_sim.configure_for(circuit_stats, FrameSimulatorMode::STORE_DETECTIONS_TO_MEMORY, num_shots);
frame_sim.reset_all_and_run(circuit);
frame_sim.reset_all();
frame_sim.do_circuit(circuit);

const auto &det_data = frame_sim.det_record.storage;
const auto &obs_data = frame_sim.obs_record;
Expand Down
Loading
Loading