Skip to content

Commit

Permalink
Felt252: implement shl and shr logics (keep-starknet-strange#96)
Browse files Browse the repository at this point in the history
* add overflowing_shl logic and unit tests

* add wrapping_shl function and unit tests

* add saturating_shl function and unit tests

* add checked_shl function and unit tests

* add shr logic

* clean up

* fix overflowing_shr

* test

* back

---------

Co-authored-by: Abdel @ StarkWare <[email protected]>
  • Loading branch information
tcoratger and AbdelStark authored Nov 3, 2023
1 parent 280b016 commit 6f683b4
Show file tree
Hide file tree
Showing 2 changed files with 689 additions and 1 deletion.
312 changes: 311 additions & 1 deletion src/math/fields/fields.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,32 @@ pub fn Field(
comptime mod: u256,
) type {
return struct {
const Self = @This();

/// Number of bits needed to represent a field element with the given modulo.
pub const BitSize = @bitSizeOf(u256) - @clz(mod);
/// Number of bytes required to store a field element.
pub const BytesSize = @sizeOf(u256);
/// The modulo value representing the finite field.
pub const Modulo = mod;
/// Half of the modulo value (Modulo - 1) divided by 2.
pub const QMinOneDiv2 = (Modulo - 1) / 2;
/// The number of bits in each limb (typically 64 for u64).
pub const Bits: usize = 64;
/// Bit mask for the last limb.
pub const Mask: u64 = mask(Bits);
/// Number of limbs used to represent a field element.
pub const Limbs: usize = 4;
/// The smallest value that can be represented by this integer type.
pub const Min = Self.zero();
/// The largest value that can be represented by this integer type.
pub const Max: Self = .{ .fe = .{
std.math.maxInt(u64),
std.math.maxInt(u64),
std.math.maxInt(u64),
std.math.maxInt(u64),
} };

const Self = @This();
const base_zero = val: {
var bz: F.MontgomeryDomainFieldElement = undefined;
F.fromBytes(
Expand All @@ -24,6 +44,19 @@ pub fn Field(

fe: F.MontgomeryDomainFieldElement,

/// Mask to apply to the highest limb to get the correct number of bits.
pub fn mask(bits: usize) u64 {
if (bits == 0) {
return 0;
}
const _bits = @mod(bits, 64);
if (_bits == 0) {
return std.math.maxInt(u64);
} else {
return std.math.shl(u64, 1, _bits) - 1;
}
}

/// Create a field element from an integer in Montgomery representation.
///
/// This function converts an integer to a field element in Montgomery form.
Expand Down Expand Up @@ -504,5 +537,282 @@ pub fn Field(
else => false,
};
}

/// Left shift by `rhs` bits with overflow detection.
///
/// This function shifts the value left by `rhs` bits and detects overflow.
/// It returns the result of the shift and a boolean indicating whether overflow occurred.
///
/// If the product $\mod{\mathtt{value} ⋅ 2^{\mathtt{rhs}}}_{2^{\mathtt{BITS}}}$ is greater than or equal to 2^BITS, it returns true.
/// In other words, it returns true if the bits shifted out are non-zero.
///
/// # Parameters
///
/// - `self`: The value to be shifted.
/// - `rhs`: The number of bits to shift left.
///
/// # Returns
///
/// A tuple containing the shifted value and a boolean indicating overflow.
pub fn overflowing_shl(
self: Self,
rhs: usize,
) std.meta.Tuple(&.{ Self, bool }) {
const limbs = rhs / 64;
const bits = @mod(rhs, 64);

if (limbs >= Limbs) {
return .{
Self.zero(),
!self.equal(Self.zero()),
};
}
var res = self;
if (bits == 0) {
// Check for overflow
var overflow = false;
for (Limbs - limbs..Limbs) |i| {
overflow = overflow or (res.fe[i] != 0);
}
if (res.fe[Limbs - limbs - 1] > Self.Mask) {
overflow = true;
}

// Shift
var idx = Limbs - 1;
while (idx >= limbs) : (idx -= 1) {
res.fe[idx] = res.fe[idx - limbs];
}
for (0..limbs) |i| {
res.fe[i] = 0;
}
res.fe[Limbs - 1] &= Self.Mask;
return .{ res, overflow };
}

// Check for overflow
var overflow = false;
for (Limbs - limbs..Limbs) |i| {
overflow = overflow or (res.fe[i] != 0);
}

if (std.math.shr(
u64,
res.fe[Limbs - limbs - 1],
64 - bits,
) != 0) {
overflow = true;
}
if (std.math.shl(
u64,
res.fe[Limbs - limbs - 1],
bits,
) > Self.Mask) {
overflow = true;
}

// Shift
var idx = Limbs - 1;
while (idx > limbs) : (idx -= 1) {
res.fe[idx] = std.math.shl(
u64,
res.fe[idx - limbs],
bits,
) | std.math.shr(
u64,
res.fe[idx - limbs - 1],
64 - bits,
);
}

res.fe[limbs] = std.math.shl(
u64,
res.fe[0],
bits,
);
for (0..limbs) |i| {
res.fe[i] = 0;
}
res.fe[Limbs - 1] &= Self.Mask;
return .{ res, overflow };
}

/// Left shift by `rhs` bits with wrapping behavior.
///
/// This function shifts the value left by `rhs` bits, and it wraps around if an overflow occurs.
/// It returns the result of the shift.
///
/// # Parameters
///
/// - `self`: The value to be shifted.
/// - `rhs`: The number of bits to shift left.
///
/// # Returns
///
/// The shifted value with wrapping behavior.
pub fn wrapping_shl(self: Self, rhs: usize) Self {
return self.overflowing_shl(rhs)[0];
}

/// Left shift by `rhs` bits with saturation.
///
/// This function shifts the value left by `rhs` bits with saturation behavior.
/// If an overflow occurs, it returns `Self.Max`, otherwise, it returns the result of the shift.
///
/// # Parameters
///
/// - `self`: The value to be shifted.
/// - `rhs`: The number of bits to shift left.
///
/// # Returns
///
/// The shifted value with saturation behavior, or `Self.Max` on overflow.
pub fn saturating_shl(self: Self, rhs: usize) Self {
const _shl = self.overflowing_shl(rhs);
return switch (_shl[1]) {
false => _shl[0],
else => Self.Max,
};
}

/// Checked left shift by `rhs` bits.
///
/// This function performs a left shift of `self` by `rhs` bits. It returns `Some(value)` if the result is less than `2^BITS`, where `value` is the shifted result. If the result
/// would be greater than or equal to `2^BITS`, it returns [`null`], indicating an overflow condition where the shifted-out bits would be non-zero.
///
/// # Parameters
///
/// - `self`: The value to be shifted.
/// - `rhs`: The number of bits to shift left.
///
/// # Returns
///
/// - `Some(value)`: The shifted value if no overflow occurs.
/// - [`null`]: If the bits shifted out would be non-zero.
pub fn checked_shl(self: Self, rhs: usize) ?Self {
const _shl = self.overflowing_shl(rhs);
return switch (_shl[1]) {
false => _shl[0],
else => null,
};
}

/// Right shift by `rhs` bits with underflow detection.
///
/// This function performs a right shift of `self` by `rhs` bits. It returns the
/// floor value of the division $\floor{\frac{\mathtt{self}}{2^{\mathtt{rhs}}}}$
/// and a boolean indicating whether the division was exact (false) or rounded down (true).
///
/// # Parameters
///
/// - `self`: The value to be shifted.
/// - `rhs`: The number of bits to shift right.
///
/// # Returns
///
/// A tuple containing the shifted value and a boolean indicating underflow.
pub fn overflowing_shr(
self: Self,
rhs: usize,
) std.meta.Tuple(&.{ Self, bool }) {
const limbs = rhs / 64;
const bits = @mod(rhs, 64);

if (limbs >= Limbs) {
return .{
Self.zero(),
!self.equal(Self.zero()),
};
}

var res = self;
if (bits == 0) {
// Check for overflow
var overflow = false;
for (0..limbs) |i| {
overflow = overflow or (res.fe[i] != 0);
}

// Shift
for (0..Limbs - limbs) |i| {
res.fe[i] = res.fe[i + limbs];
}
for (Limbs - limbs..Limbs) |i| {
res.fe[i] = 0;
}
return .{ res, overflow };
}

// Check for overflow
var overflow = false;
for (0..limbs) |i| {
overflow = overflow or (res.fe[i] != 0);
}
overflow = overflow or (std.math.shr(
u64,
res.fe[limbs],
bits,
) != 0);

// Shift
for (0..Limbs - limbs - 1) |i| {
res.fe[i] = std.math.shr(
u64,
res.fe[i + limbs],
bits,
) | std.math.shl(
u64,
res.fe[i + limbs + 1],
64 - bits,
);
}

res.fe[Limbs - limbs - 1] = std.math.shr(
u64,
res.fe[Limbs - 1],
bits,
);
for (Limbs - limbs..Limbs) |i| {
res.fe[i] = 0;
}
return .{ res, overflow };
}

/// Right shift by `rhs` bits with checked underflow.
///
/// This function performs a right shift of `self` by `rhs` bits. It returns `Some(value)` with the result of the shift if no underflow occurs. If underflow happens (bits are shifted out), it returns [`null`].
///
/// # Parameters
///
/// - `self`: The value to be shifted.
/// - `rhs`: The number of bits to shift right.
///
/// # Returns
///
/// - `Some(value)`: The shifted value if no underflow occurs.
/// - [`null`]: If the division is not exact.
pub fn checked_shr(self: Self, rhs: usize) ?Self {
const _shl = self.overflowing_shr(rhs);
return switch (_shl[1]) {
false => _shl[0],
else => null,
};
}

/// Right shift by `rhs` bits with wrapping behavior.
///
/// This function performs a right shift of `self` by `rhs` bits, and it wraps around if an underflow occurs. It returns the result of the shift.
///
/// # Parameters
///
/// - `self`: The value to be shifted.
/// - `rhs`: The number of bits to shift right.
///
/// # Returns
///
/// The shifted value with wrapping behavior.
pub fn wrapping_shr(self: Self, rhs: usize) Self {
return self.overflowing_shr(rhs)[0];
}
};
}
Loading

0 comments on commit 6f683b4

Please sign in to comment.