Skip to content

Commit

Permalink
Add left/right shift and addition to simd_bits. (#603)
Browse files Browse the repository at this point in the history
From #598 added +=, >>= and <<= to simd_bits. It wasn't obvious to me that these could use word level parallelism without using more memory? For example, the shifts could store the relevant carry masks and or these at the end but this would require a temporary of the same size as the simd_bits instance.
  • Loading branch information
fdmalone authored Aug 16, 2023
1 parent 6fb8663 commit d44e4b6
Show file tree
Hide file tree
Showing 6 changed files with 390 additions and 2 deletions.
6 changes: 6 additions & 0 deletions src/stim/mem/simd_bits.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ struct simd_bits {
// Mask assignment.
simd_bits &operator&=(const simd_bits_range_ref<W> other);
simd_bits &operator|=(const simd_bits_range_ref<W> other);
// Addition assignment.
simd_bits &operator+=(const simd_bits_range_ref<W> other);
// Right shift assignment.
simd_bits &operator>>=(int offset);
// Left shift assignment.
simd_bits &operator<<=(int offset);
// Swap assignment.
simd_bits &swap_with(simd_bits_range_ref<W> other);

Expand Down
20 changes: 19 additions & 1 deletion src/stim/mem/simd_bits.inl
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,24 @@ simd_bits<W> &simd_bits<W>::operator|=(const simd_bits_range_ref<W> other) {
return *this;
}

/// In-place addition: forwards to `simd_bits_range_ref<W>::operator+=` on a range
/// reference constructed from this object, then returns this object.
template <size_t W>
simd_bits<W> &simd_bits<W>::operator+=(const simd_bits_range_ref<W> other) {
    simd_bits_range_ref<W> self_range(*this);
    self_range += other;
    return *this;
}

/// In-place right shift: forwards to `simd_bits_range_ref<W>::operator>>=` on a range
/// reference constructed from this object, then returns this object.
template <size_t W>
simd_bits<W> &simd_bits<W>::operator>>=(int offset) {
    simd_bits_range_ref<W> self_range(*this);
    self_range >>= offset;
    return *this;
}

/// In-place left shift: forwards to `simd_bits_range_ref<W>::operator<<=` on a range
/// reference constructed from this object, then returns this object.
template <size_t W>
simd_bits<W> &simd_bits<W>::operator<<=(int offset) {
    simd_bits_range_ref<W> self_range(*this);
    self_range <<= offset;
    return *this;
}

template <size_t W>
bool simd_bits<W>::not_zero() const {
return simd_bits_range_ref<W>(*this).not_zero();
Expand Down Expand Up @@ -289,4 +307,4 @@ std::ostream &operator<<(std::ostream &out, const simd_bits<W> m) {
return out << simd_bits_range_ref<W>(m);
}

}
} // namespace stim
193 changes: 193 additions & 0 deletions src/stim/mem/simd_bits.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "stim/mem/simd_bits.h"

#include <random>

#include "gtest/gtest.h"

#include "stim/mem/simd_util.h"
Expand Down Expand Up @@ -160,6 +162,197 @@ TEST_EACH_WORD_SIZE_W(simd_bits, xor_assignment, {
}
})

// Exercises simd_bits::operator+= on 512-bit vectors in four phases:
//   1. adding two identical patterned vectors (carry from even words into odd words),
//   2. adding an all-zero vector (value must not change),
//   3. adding the value 1 repeatedly (512 times),
//   4. adding 1 to a completely set low word (carry into bit 64).
// State in m0/m1 deliberately carries over from one phase to the next.
TEST_EACH_WORD_SIZE_W(simd_bits, add_assignment, {
simd_bits<W> m0(512);
simd_bits<W> m1(512);
uint64_t all_set = 0xFFFFFFFFFFFFFFFFULL;
uint64_t on_off = 0x0F0F0F0F0F0F0F0FULL;
// Phase 1 setup: fill both operands identically, bit by bit.
// Even words get all ones, odd words get the 0x0F0F...0F pattern.
for (size_t word = 0; word < m0.num_u64_padded(); word++) {
for (size_t k = 0; k < 64; k++) {
if (word % 2 == 0) {
m0[word * 64 + k] = all_set & (1ULL << k);
m1[word * 64 + k] = all_set & (1ULL << k);
} else {
m0[word * 64 + k] = (bool)(on_off & (1ULL << k));
m1[word * 64 + k] = (bool)(on_off & (1ULL << k));
}
}
}
m0 += m1;
// Read each 64-bit word back out bit by bit and compare against the expected sum.
for (size_t word = 0; word < m0.num_u64_padded(); word++) {
uint64_t pattern = 0ULL;
for (size_t k = 0; k < 64; k++) {
pattern |= (uint64_t{m0[word * 64 + k]} << k);
}
if (word % 2 == 0) {
// all ones + all ones wraps to 0x...FE and carries 1 into the following odd word.
ASSERT_EQ(pattern, 0xFFFFFFFFFFFFFFFEULL);
} else {
// 0x0F.. + 0x0F.. = 0x1E.., plus the carry-in of 1 from the even word below.
// No carry leaves this word, so the next even word sees no carry-in.
ASSERT_EQ(pattern, 0x1E1E1E1E1E1E1E1FULL);
}
}
// Phase 2: zero m1 by writing its u64 words directly, two at a time.
for (size_t k = 0; k < m0.num_u64_padded() / 2; k++) {
m1.u64[2 * k] = 0ULL;
m1.u64[2 * k + 1] = 0ULL;
}
m0 += m1;
// Adding zero must leave every word exactly as it was after phase 1.
for (size_t word = 0; word < m0.num_u64_padded(); word++) {
uint64_t pattern = 0ULL;
for (size_t k = 0; k < 64; k++) {
pattern |= (uint64_t{m0[word * 64 + k]} << k);
}
if (word % 2 == 0) {
ASSERT_EQ(pattern, 0xFFFFFFFFFFFFFFFEULL);
} else {
ASSERT_EQ(pattern, 0x1E1E1E1E1E1E1E1FULL);
}
}
// Phase 3: starting from zero, add the value 1 a total of 512 times.
m0.clear();
m1.clear();
m1[0] = 1;
for (int i = 0; i < 512; i++) {
m0 += m1;
}
// 512 == 1 << 9, so within the first word only bit 9 may be set.
for (size_t k = 0; k < 64; k++) {
if (k == 9) {
ASSERT_EQ(m0[k], 1);
} else {
ASSERT_EQ(m0[k], 0);
}
}
// Phase 4: set the whole low word of m0, then add 1 (m1 still holds 1 from phase 3).
m0.clear();
for (size_t k = 0; k < 64; k++) {
m0[k] = all_set & (1ULL << k);
}
m0 += m1;
// The low word overflows, so bit 0 clears and the carry appears at bit 64.
ASSERT_EQ(m0[0], 0);
ASSERT_EQ(m0[64], 1);
})

// Exercises simd_bits::operator>>= on 512-bit vectors: a whole-word shift of a single
// bit, a zero shift, a sub-word shift of a repeating pattern, and a multi-word shift.
// A right shift by s is expected to move the bit at index k + s down to index k.
TEST_EACH_WORD_SIZE_W(simd_bits, right_shift_assignment, {
simd_bits<W> m0(512), m1(512);
// Single bit at index 511, shifted down by exactly one 64-bit word.
m0[511] = 1;
m0 >>= 64;
for (size_t word = 0; word < m0.num_u64_padded(); word++) {
uint64_t pattern = 0ULL;
for (size_t k = 0; k < 64; k++) {
pattern |= (uint64_t{m0[word * 64 + k]} << k);
}
// The bit must now be the top bit of the second-to-last padded word; all else is zero.
// NOTE(review): this lines up with bit 511 - 64 only if 512 bits pad to exactly 8 words.
if (word != m0.num_u64_padded() - 2) {
ASSERT_EQ(pattern, 0ULL);
} else {
ASSERT_EQ(pattern, uint64_t{1} << 63);
}
}
// Shifting by zero must not change anything.
m1 = m0;
m1 >>= 0;
for (size_t k = 0; k < 512; k++) {
ASSERT_EQ(m0[k], m1[k]);
}
// Fill every word with 0xAA..AA (odd-indexed bits set), then shift by a single bit.
m0.clear();
uint64_t on_off = 0xAAAAAAAAAAAAAAAAULL;
for (size_t word = 0; word < m0.num_u64_padded(); word++) {
for (size_t k = 0; k < 64; k++) {
m0[word * 64 + k] = (bool)(on_off & (1ULL << k));
}
}
m0 >>= 1;
// Each word becomes 0x55..55: the bit entering at the top is bit 0 of the next word (0),
// or a shifted-in zero for the last word.
for (size_t word = 0; word < m0.num_u64_padded(); word++) {
uint64_t pattern = 0ULL;
for (size_t k = 0; k < 64; k++) {
pattern |= (uint64_t{m0[word * 64 + k]} << k);
}
ASSERT_EQ(pattern, 0x5555555555555555ULL);
}
// Refill with the same pattern and shift by two whole words (128 bits).
m0.clear();
for (size_t word = 0; word < m0.num_u64_padded(); word++) {
for (size_t k = 0; k < 64; k++) {
m0[word * 64 + k] = (bool)(on_off & (1ULL << k));
}
}
m0 >>= 128;
for (size_t word = 0; word < m0.num_u64_padded(); word++) {
uint64_t pattern = 0ULL;
for (size_t k = 0; k < 64; k++) {
pattern |= (uint64_t{m0[word * 64 + k]} << k);
}
// Words below index 6 keep the pattern; the rest are vacated and must be zero.
// NOTE(review): the literal 6 hard-codes an 8-word vector (8 - 2); confirm for every W.
if (word < 6) {
ASSERT_EQ(pattern, 0xAAAAAAAAAAAAAAAA);
} else {
ASSERT_EQ(pattern, 0ULL);
}
}
})

// Randomized check of simd_bits::operator>>=.
// Each of the 5 trials draws a size in [1, 1200], randomizes a vector of that size, keeps
// a copy, draws a shift in [0, padded bit count], and shifts. Bit k must then equal bit
// k + shift of the copy, and the topmost `shift` padded bits must be zero.
TEST_EACH_WORD_SIZE_W(simd_bits, fuzz_right_shift_assignment, {
    auto rng = SHARED_TEST_RNG();
    for (int trial = 0; trial < 5; trial++) {
        std::uniform_int_distribution size_dist(1, 1200);
        int num_bits = size_dist(rng);
        simd_bits<W> shifted(num_bits);
        simd_bits<W> before(num_bits);
        shifted.randomize(num_bits, rng);
        before = shifted;
        std::uniform_int_distribution shift_dist(0, (int)shifted.num_bits_padded());
        size_t shift = shift_dist(rng);
        shifted >>= shift;

        // Bits that survive the shift come first, followed by the vacated high bits.
        size_t k = 0;
        while (k < shifted.num_bits_padded() - shift) {
            ASSERT_EQ(shifted[k], before[k + shift]);
            k++;
        }
        while (k < shifted.num_bits_padded()) {
            ASSERT_EQ(shifted[k], 0);
            k++;
        }
    }
})

// Exercises simd_bits::operator<<= on 512-bit vectors: a one-bit shift of a repeating
// pattern, a zero shift, a 63-bit shift, and a shift large enough to clear everything.
// A left shift by s is expected to move the bit at index k up to index k + s.
// The shifts accumulate on m0; it is never reset between phases.
TEST_EACH_WORD_SIZE_W(simd_bits, left_shift_assignment, {
simd_bits<W> m0(512), m1(512);
// Start from 0xAA..AA (odd-indexed bits set) in every padded word.
for (size_t w = 0; w < m0.num_u64_padded(); w++) {
m0.u64[w] = 0xAAAAAAAAAAAAAAAAULL;
}
m0 <<= 1;
// Shifting a copy by zero must not change it.
m1 = m0;
m1 <<= 0;
for (size_t k = 0; k < 512; k++) {
ASSERT_EQ(m0[k], m1[k]);
}
// After the 1-bit shift each word is 0x55..54 plus the bit carried in from the word below.
// That carried bit is the old top bit (1) for every word except word 0, which gets a zero.
for (size_t w = 0; w < m0.num_u64_padded(); w++) {
if (w == 0) {
ASSERT_EQ(m0.u64[w], 0x5555555555555554ULL);
} else {
ASSERT_EQ(m0.u64[w], 0x5555555555555555ULL);
}
}
// A further 63 bits makes the total shift exactly one word: word 0 empties and every
// higher word holds the original 0xAA..AA pattern again.
m0 <<= 63;
for (size_t w = 0; w < m0.num_u64_padded(); w++) {
if (w == 0) {
ASSERT_EQ(m0.u64[w], 0ULL);
} else {
ASSERT_EQ(m0.u64[w], 0xAAAAAAAAAAAAAAAAULL);
}
}
// The lowest set bit is now at index 65; shifting by 488 pushes it to 553, past bit 511,
// so nothing may remain set.
m0 <<= 488;
ASSERT_TRUE(!m0.not_zero());
})

// Randomized check of simd_bits::operator<<=.
// Each of the 5 trials draws a size in [1, 1200], randomizes a vector of that size, keeps
// a copy, draws a shift in [0, padded bit count], and shifts. Bit k + shift must then
// equal bit k of the copy, and the lowest `shift` bits must be zero.
TEST_EACH_WORD_SIZE_W(simd_bits, fuzz_left_shift_assignment, {
    auto rng = SHARED_TEST_RNG();
    for (int trial = 0; trial < 5; trial++) {
        std::uniform_int_distribution size_dist(1, 1200);
        int num_bits = size_dist(rng);
        simd_bits<W> shifted(num_bits);
        simd_bits<W> before(num_bits);
        shifted.randomize(num_bits, rng);
        before = shifted;
        std::uniform_int_distribution shift_dist(0, (int)shifted.num_bits_padded());
        size_t shift = shift_dist(rng);
        shifted <<= shift;

        // First the bits that were moved upwards...
        size_t src = 0;
        while (src < shifted.num_bits_padded() - shift) {
            ASSERT_EQ(shifted[src + shift], before[src]);
            src++;
        }
        // ...then the vacated low bits.
        size_t low = 0;
        while (low < shift) {
            ASSERT_EQ(shifted[low], 0);
            low++;
        }
    }
})

TEST_EACH_WORD_SIZE_W(simd_bits, assignment, {
simd_bits<W> m0(512);
simd_bits<W> m1(512);
Expand Down
5 changes: 5 additions & 0 deletions src/stim/mem/simd_bits_range_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ struct simd_bits_range_ref {
/// Mask assignment.
simd_bits_range_ref operator&=(const simd_bits_range_ref other);
simd_bits_range_ref operator|=(const simd_bits_range_ref other);
/// Addition assignment.
simd_bits_range_ref operator+=(const simd_bits_range_ref<W> other);
/// Shift assignment.
simd_bits_range_ref operator>>=(int offset);
simd_bits_range_ref operator<<=(int offset);
/// Swap assignment.
void swap_with(simd_bits_range_ref other);

Expand Down
71 changes: 70 additions & 1 deletion src/stim/mem/simd_bits_range_ref.inl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,75 @@ simd_bits_range_ref<W> simd_bits_range_ref<W>::operator=(const simd_bits_range_r
return *this;
}

/// Adds `other` into this range in place, treating both ranges as little-endian unsigned
/// integers made of 64-bit words (word 0 is least significant).
///
/// The carry out of the most significant padded word is discarded, so the sum is taken
/// modulo 2^(64 * num_u64_padded()).
///
/// Args:
///     other: The addend. It is read over this range's padded word count.
///         NOTE(review): assumes `other` has at least as many padded words as this
///         range; confirm at call sites. It may alias this range (`x += x`).
///
/// Returns:
///     This range reference, by value.
template <size_t W>
simd_bits_range_ref<W> simd_bits_range_ref<W>::operator+=(const simd_bits_range_ref<W> other) {
    const size_t num_u64 = num_u64_padded();
    uint64_t carry = 0;
    for (size_t w = 0; w < num_u64; w++) {
        // Read the addend before writing so the carry test stays valid if `other` aliases us.
        const uint64_t addend = other.u64[w];
        const uint64_t partial = u64[w] + addend;
        // Unsigned wraparound: the add overflowed iff the result is below an operand.
        uint64_t carry_out = partial < addend;
        const uint64_t sum = partial + carry;
        // Adding the carry-in can itself overflow (partial was all ones). The two overflows
        // are mutually exclusive, so OR is enough.
        carry_out |= sum < partial;
        u64[w] = sum;
        carry = carry_out;
    }
    return *this;
}

/// Shifts the bits of this range towards lower indices by `offset` positions, in place.
///
/// Afterwards bit k holds the value previously stored at bit k + offset, and the highest
/// `offset` bits (counted over the padded 64-bit words) are zero.
///
/// Args:
///     offset: Number of bit positions to shift by. Zero or negative values leave the
///         range untouched; a negative amount would otherwise reach a built-in shift with
///         undefined behavior. Values at or beyond the padded bit count clear the range.
///
/// Returns:
///     This range reference, by value.
template <size_t W>
simd_bits_range_ref<W> simd_bits_range_ref<W>::operator>>=(int offset) {
    if (offset <= 0) {
        return *this;
    }
    const size_t num_words = num_u64_padded();
    const size_t word_shift = static_cast<size_t>(offset) / 64;
    const size_t bit_shift = static_cast<size_t>(offset) % 64;

    if (word_shift >= num_words) {
        // Every bit is shifted out.
        for (size_t w = 0; w < num_words; w++) {
            u64[w] = 0;
        }
        return *this;
    }

    // Single ascending pass: each destination word is built from source words at equal or
    // higher indices, none of which has been overwritten yet.
    const size_t num_kept = num_words - word_shift;
    for (size_t w = 0; w < num_kept; w++) {
        uint64_t shifted = u64[w + word_shift] >> bit_shift;
        // Pull in the low bits of the next source word. Skipped when bit_shift is 0,
        // because a shift by 64 is undefined behavior.
        if (bit_shift != 0 && w + 1 < num_kept) {
            shifted |= u64[w + word_shift + 1] << (64 - bit_shift);
        }
        u64[w] = shifted;
    }
    // Words vacated at the top are filled with zeros.
    for (size_t w = num_kept; w < num_words; w++) {
        u64[w] = 0;
    }
    return *this;
}

/// Shifts the bits of this range towards higher indices by `offset` positions, in place.
///
/// Afterwards bit k + offset holds the value previously stored at bit k, and the lowest
/// `offset` bits are zero. Bits pushed past the last padded 64-bit word are discarded.
///
/// Args:
///     offset: Number of bit positions to shift by. Zero or negative values leave the
///         range untouched; a negative amount would otherwise reach a built-in shift with
///         undefined behavior. Values at or beyond the padded bit count clear the range.
///
/// Returns:
///     This range reference, by value.
template <size_t W>
simd_bits_range_ref<W> simd_bits_range_ref<W>::operator<<=(int offset) {
    if (offset <= 0) {
        return *this;
    }
    const size_t num_words = num_u64_padded();
    const size_t word_shift = static_cast<size_t>(offset) / 64;
    const size_t bit_shift = static_cast<size_t>(offset) % 64;

    if (word_shift >= num_words) {
        // Every bit is shifted out.
        for (size_t w = 0; w < num_words; w++) {
            u64[w] = 0;
        }
        return *this;
    }

    // Single descending pass: each destination word is built from source words at equal or
    // lower indices, none of which has been overwritten yet.
    for (size_t w = num_words; w-- > word_shift;) {
        uint64_t shifted = u64[w - word_shift] << bit_shift;
        // Pull in the high bits of the next lower source word. Skipped when bit_shift is 0,
        // because a shift by 64 is undefined behavior.
        if (bit_shift != 0 && w > word_shift) {
            shifted |= u64[w - word_shift - 1] >> (64 - bit_shift);
        }
        u64[w] = shifted;
    }
    // Words vacated at the bottom are filled with zeros.
    for (size_t w = 0; w < word_shift; w++) {
        u64[w] = 0;
    }
    return *this;
}

template <size_t W>
void simd_bits_range_ref<W>::swap_with(simd_bits_range_ref<W> other) {
for_each_word(other, [](bitword<W> &w0, bitword<W> &w1) {
Expand Down Expand Up @@ -153,4 +222,4 @@ bool simd_bits_range_ref<W>::intersects(const simd_bits_range_ref<W> other) cons
return v != 0;
}

}
} // namespace stim
Loading

0 comments on commit d44e4b6

Please sign in to comment.