From 6f683b4c98c75b5cf4ce36292f12ac3457619968 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Fri, 3 Nov 2023 11:11:46 +0100 Subject: [PATCH] `Felt252`: implement `shl` and `shr` logics (#96) * add overflowing_shl logic and unit tests * add wrapping_shl function and unit tests * add saturating_shl function and unit tests * add checked_shl function and unit tests * add shr logic * clean up * fix overflowing_shr * test * back --------- Co-authored-by: Abdel @ StarkWare <45264458+abdelhamidbakhta@users.noreply.github.com> --- src/math/fields/fields.zig | 312 ++++++++++++++++++++++++++++- src/math/fields/starknet.zig | 378 +++++++++++++++++++++++++++++++++++ 2 files changed, 689 insertions(+), 1 deletion(-) diff --git a/src/math/fields/fields.zig b/src/math/fields/fields.zig index e6fd3184..d749bf35 100644 --- a/src/math/fields/fields.zig +++ b/src/math/fields/fields.zig @@ -7,12 +7,32 @@ pub fn Field( comptime mod: u256, ) type { return struct { + const Self = @This(); + + /// Number of bits needed to represent a field element with the given modulo. pub const BitSize = @bitSizeOf(u256) - @clz(mod); + /// Number of bytes required to store a field element. pub const BytesSize = @sizeOf(u256); + /// The modulo value representing the finite field. pub const Modulo = mod; + /// Half of the modulo value (Modulo - 1) divided by 2. pub const QMinOneDiv2 = (Modulo - 1) / 2; + /// The number of bits in each limb (typically 64 for u64). + pub const Bits: usize = 64; + /// Bit mask for the last limb. + pub const Mask: u64 = mask(Bits); + /// Number of limbs used to represent a field element. + pub const Limbs: usize = 4; + /// The smallest value that can be represented by this integer type. + pub const Min = Self.zero(); + /// The largest value that can be represented by this integer type. + pub const Max: Self = .{ .fe = .{ + std.math.maxInt(u64), + std.math.maxInt(u64), + std.math.maxInt(u64), + std.math.maxInt(u64), + } }; - const Self = @This(); const base_zero = val: { var bz: F.MontgomeryDomainFieldElement = undefined; F.fromBytes( @@ -24,6 +44,19 @@ pub fn Field( fe: F.MontgomeryDomainFieldElement, + /// Mask to apply to the highest limb to get the correct number of bits. + pub fn mask(bits: usize) u64 { + if (bits == 0) { + return 0; + } + const _bits = @mod(bits, 64); + if (_bits == 0) { + return std.math.maxInt(u64); + } else { + return std.math.shl(u64, 1, _bits) - 1; + } + } + /// Create a field element from an integer in Montgomery representation. /// /// This function converts an integer to a field element in Montgomery form. @@ -504,5 +537,282 @@ pub fn Field( else => false, }; } + + /// Left shift by `rhs` bits with overflow detection. + /// + /// This function shifts the value left by `rhs` bits and detects overflow. + /// It returns the result of the shift and a boolean indicating whether overflow occurred. + /// + /// If the product $\mod{\mathtt{value} ⋅ 2^{\mathtt{rhs}}}_{2^{\mathtt{BITS}}}$ is greater than or equal to 2^BITS, it returns true. + /// In other words, it returns true if the bits shifted out are non-zero. + /// + /// # Parameters + /// + /// - `self`: The value to be shifted. + /// - `rhs`: The number of bits to shift left. + /// + /// # Returns + /// + /// A tuple containing the shifted value and a boolean indicating overflow. + pub fn overflowing_shl( + self: Self, + rhs: usize, + ) std.meta.Tuple(&.{ Self, bool }) { + const limbs = rhs / 64; + const bits = @mod(rhs, 64); + + if (limbs >= Limbs) { + return .{ + Self.zero(), + !self.equal(Self.zero()), + }; + } + var res = self; + if (bits == 0) { + // Check for overflow + var overflow = false; + for (Limbs - limbs..Limbs) |i| { + overflow = overflow or (res.fe[i] != 0); + } + if (res.fe[Limbs - limbs - 1] > Self.Mask) { + overflow = true; + } + + // Shift + var idx = Limbs - 1; + while (idx >= limbs) : (idx -= 1) { + res.fe[idx] = res.fe[idx - limbs]; + } + for (0..limbs) |i| { + res.fe[i] = 0; + } + res.fe[Limbs - 1] &= Self.Mask; + return .{ res, overflow }; + } + + // Check for overflow + var overflow = false; + for (Limbs - limbs..Limbs) |i| { + overflow = overflow or (res.fe[i] != 0); + } + + if (std.math.shr( + u64, + res.fe[Limbs - limbs - 1], + 64 - bits, + ) != 0) { + overflow = true; + } + if (std.math.shl( + u64, + res.fe[Limbs - limbs - 1], + bits, + ) > Self.Mask) { + overflow = true; + } + + // Shift + var idx = Limbs - 1; + while (idx > limbs) : (idx -= 1) { + res.fe[idx] = std.math.shl( + u64, + res.fe[idx - limbs], + bits, + ) | std.math.shr( + u64, + res.fe[idx - limbs - 1], + 64 - bits, + ); + } + + res.fe[limbs] = std.math.shl( + u64, + res.fe[0], + bits, + ); + for (0..limbs) |i| { + res.fe[i] = 0; + } + res.fe[Limbs - 1] &= Self.Mask; + return .{ res, overflow }; + } + + /// Left shift by `rhs` bits with wrapping behavior. + /// + /// This function shifts the value left by `rhs` bits, and it wraps around if an overflow occurs. + /// It returns the result of the shift. + /// + /// # Parameters + /// + /// - `self`: The value to be shifted. + /// - `rhs`: The number of bits to shift left. + /// + /// # Returns + /// + /// The shifted value with wrapping behavior. + pub fn wrapping_shl(self: Self, rhs: usize) Self { + return self.overflowing_shl(rhs)[0]; + } + + /// Left shift by `rhs` bits with saturation. + /// + /// This function shifts the value left by `rhs` bits with saturation behavior. + /// If an overflow occurs, it returns `Self.Max`, otherwise, it returns the result of the shift. + /// + /// # Parameters + /// + /// - `self`: The value to be shifted. + /// - `rhs`: The number of bits to shift left. + /// + /// # Returns + /// + /// The shifted value with saturation behavior, or `Self.Max` on overflow. + pub fn saturating_shl(self: Self, rhs: usize) Self { + const _shl = self.overflowing_shl(rhs); + return switch (_shl[1]) { + false => _shl[0], + else => Self.Max, + }; + } + + /// Checked left shift by `rhs` bits. + /// + /// This function performs a left shift of `self` by `rhs` bits. It returns `Some(value)` if the result is less than `2^BITS`, where `value` is the shifted result. If the result + /// would be greater than or equal to `2^BITS`, it returns [`null`], indicating an overflow condition where the shifted-out bits would be non-zero. + /// + /// # Parameters + /// + /// - `self`: The value to be shifted. + /// - `rhs`: The number of bits to shift left. + /// + /// # Returns + /// + /// - `Some(value)`: The shifted value if no overflow occurs. + /// - [`null`]: If the bits shifted out would be non-zero. + pub fn checked_shl(self: Self, rhs: usize) ?Self { + const _shl = self.overflowing_shl(rhs); + return switch (_shl[1]) { + false => _shl[0], + else => null, + }; + } + + /// Right shift by `rhs` bits with underflow detection. + /// + /// This function performs a right shift of `self` by `rhs` bits. It returns the + /// floor value of the division $\floor{\frac{\mathtt{self}}{2^{\mathtt{rhs}}}}$ + /// and a boolean indicating whether the division was exact (false) or rounded down (true). + /// + /// # Parameters + /// + /// - `self`: The value to be shifted. + /// - `rhs`: The number of bits to shift right. + /// + /// # Returns + /// + /// A tuple containing the shifted value and a boolean indicating underflow. + pub fn overflowing_shr( + self: Self, + rhs: usize, + ) std.meta.Tuple(&.{ Self, bool }) { + const limbs = rhs / 64; + const bits = @mod(rhs, 64); + + if (limbs >= Limbs) { + return .{ + Self.zero(), + !self.equal(Self.zero()), + }; + } + + var res = self; + if (bits == 0) { + // Check for overflow + var overflow = false; + for (0..limbs) |i| { + overflow = overflow or (res.fe[i] != 0); + } + + // Shift + for (0..Limbs - limbs) |i| { + res.fe[i] = res.fe[i + limbs]; + } + for (Limbs - limbs..Limbs) |i| { + res.fe[i] = 0; + } + return .{ res, overflow }; + } + + // Check for overflow + var overflow = false; + for (0..limbs) |i| { + overflow = overflow or (res.fe[i] != 0); + } + overflow = overflow or (std.math.shr( + u64, + res.fe[limbs], + bits, + ) != 0); + + // Shift + for (0..Limbs - limbs - 1) |i| { + res.fe[i] = std.math.shr( + u64, + res.fe[i + limbs], + bits, + ) | std.math.shl( + u64, + res.fe[i + limbs + 1], + 64 - bits, + ); + } + + res.fe[Limbs - limbs - 1] = std.math.shr( + u64, + res.fe[Limbs - 1], + bits, + ); + for (Limbs - limbs..Limbs) |i| { + res.fe[i] = 0; + } + return .{ res, overflow }; + } + + /// Right shift by `rhs` bits with checked underflow. + /// + /// This function performs a right shift of `self` by `rhs` bits. It returns `Some(value)` with the result of the shift if no underflow occurs. If underflow happens (bits are shifted out), it returns [`null`]. + /// + /// # Parameters + /// + /// - `self`: The value to be shifted. + /// - `rhs`: The number of bits to shift right. + /// + /// # Returns + /// + /// - `Some(value)`: The shifted value if no underflow occurs. + /// - [`null`]: If the division is not exact. + pub fn checked_shr(self: Self, rhs: usize) ?Self { + const _shl = self.overflowing_shr(rhs); + return switch (_shl[1]) { + false => _shl[0], + else => null, + }; + } + + /// Right shift by `rhs` bits with wrapping behavior. + /// + /// This function performs a right shift of `self` by `rhs` bits, and it wraps around if an underflow occurs. It returns the result of the shift. + /// + /// # Parameters + /// + /// - `self`: The value to be shifted. + /// - `rhs`: The number of bits to shift right. + /// + /// # Returns + /// + /// The shifted value with wrapping behavior. + pub fn wrapping_shr(self: Self, rhs: usize) Self { + return self.overflowing_shr(rhs)[0]; + } }; } diff --git a/src/math/fields/starknet.zig b/src/math/fields/starknet.zig index fb80d882..c415776a 100644 --- a/src/math/fields/starknet.zig +++ b/src/math/fields/starknet.zig @@ -529,3 +529,381 @@ test "Felt252 lexographicallyLargest" { ).lexographicallyLargest()); try expect(Felt252.fromInteger(std.math.maxInt(u256)).lexographicallyLargest()); } + +test "Felt252 overflowing_shl" { + var a = Felt252.fromInteger(10); + try expectEqual( + @as( + std.meta.Tuple(&.{ Felt252, bool }), + .{ + Felt252{ .fe = .{ + 0xfffffffffffffd82, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xfffffffffffd5a1, + } }, + false, + }, + ), + a.overflowing_shl(1), + ); + var b = Felt252.fromInteger(std.math.maxInt(u256)); + try expectEqual( + @as( + std.meta.Tuple(&.{ Felt252, bool }), + .{ + Felt252{ .fe = .{ + 0xffffae6fc0008420, + 0x2661ffffff, + 0xffffffffedf00000, + 0xfffa956bc011461f, + } }, + false, + }, + ), + b.overflowing_shl(5), + ); + var c = Felt252.fromInteger(44444444); + try expectEqual( + @as( + std.meta.Tuple(&.{ Felt252, bool }), + .{ + Felt252{ .fe = .{ + 0xfffffeacea720400, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffe97b919243ff, + } }, + true, + }, + ), + c.overflowing_shl(10), + ); + try expectEqual( + @as( + std.meta.Tuple(&.{ Felt252, bool }), + .{ + Felt252.zero(), + true, + }, + ), + c.overflowing_shl(5 * 64), + ); + var d = Felt252.fromInteger(33333333); + try expectEqual( + @as( + std.meta.Tuple(&.{ Felt252, bool }), + .{ + Felt252{ .fe = .{ 0x0, 0x0, 0x0, 0xffffffffc06bf561 } }, + true, + }, + ), + d.overflowing_shl(3 * 64), + ); +} + +test "Felt252 wrapping_shl" { + var a = Felt252.fromInteger(10); + try expectEqual( + Felt252{ .fe = .{ + 0xfffffffffffffd82, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xfffffffffffd5a1, + } }, + a.wrapping_shl(1), + ); + var b = Felt252.fromInteger(std.math.maxInt(u256)); + try expectEqual( + Felt252{ .fe = .{ + 0xffffae6fc0008420, + 0x2661ffffff, + 0xffffffffedf00000, + 0xfffa956bc011461f, + } }, + b.wrapping_shl(5), + ); + var c = Felt252.fromInteger(44444444); + try expectEqual( + Felt252{ .fe = .{ + 0xfffffeacea720400, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffe97b919243ff, + } }, + c.wrapping_shl(10), + ); + try expectEqual( + Felt252.zero(), + c.wrapping_shl(5 * 64), + ); + var d = Felt252.fromInteger(33333333); + try expectEqual( + Felt252{ .fe = .{ 0x0, 0x0, 0x0, 0xffffffffc06bf561 } }, + d.wrapping_shl(3 * 64), + ); +} + +test "Felt252 saturating_shl" { + var a = Felt252.fromInteger(10); + try expectEqual( + Felt252{ .fe = .{ + 0xfffffffffffffd82, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xfffffffffffd5a1, + } }, + a.saturating_shl(1), + ); + var b = Felt252.fromInteger(std.math.maxInt(u256)); + try expectEqual( + Felt252{ .fe = .{ + 0xffffae6fc0008420, + 0x2661ffffff, + 0xffffffffedf00000, + 0xfffa956bc011461f, + } }, + b.saturating_shl(5), + ); + var c = Felt252.fromInteger(44444444); + try expectEqual( + Felt252{ .fe = .{ + std.math.maxInt(u64), + std.math.maxInt(u64), + std.math.maxInt(u64), + std.math.maxInt(u64), + } }, + c.saturating_shl(10), + ); + try expectEqual( + Felt252{ .fe = .{ + std.math.maxInt(u64), + std.math.maxInt(u64), + std.math.maxInt(u64), + std.math.maxInt(u64), + } }, + c.saturating_shl(5 * 64), + ); + var d = Felt252.fromInteger(33333333); + try expectEqual( + Felt252{ .fe = .{ + std.math.maxInt(u64), + std.math.maxInt(u64), + std.math.maxInt(u64), + std.math.maxInt(u64), + } }, + d.saturating_shl(3 * 64), + ); +} + +test "Felt252 checked_shl" { + var a = Felt252.fromInteger(10); + try expectEqual( + Felt252{ .fe = .{ + 0xfffffffffffffd82, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xfffffffffffd5a1, + } }, + a.checked_shl(1).?, + ); + var b = Felt252.fromInteger(std.math.maxInt(u256)); + try expectEqual( + Felt252{ .fe = .{ + 0xffffae6fc0008420, + 0x2661ffffff, + 0xffffffffedf00000, + 0xfffa956bc011461f, + } }, + b.checked_shl(5).?, + ); + var c = Felt252.fromInteger(44444444); + try expectEqual( + @as(?Felt252, null), + c.checked_shl(10), + ); + try expectEqual( + @as(?Felt252, null), + c.checked_shl(5 * 64), + ); + var d = Felt252.fromInteger(33333333); + try expectEqual( + @as(?Felt252, null), + d.checked_shl(3 * 64), + ); +} + +test "Felt252 overflowing_shr" { + var a = Felt252.fromInteger(10); + try expectEqual( + @as( + std.meta.Tuple(&.{ Felt252, bool }), + .{ + Felt252{ .fe = .{ + 0xffffffffffffff60, + 0xffffffffffffffff, + 0x7fffffffffffffff, + 0x3fffffffffff568, + } }, + true, + }, + ), + a.overflowing_shr(1), + ); + var b = Felt252.fromInteger(std.math.maxInt(u256)); + try expectEqual( + @as( + std.meta.Tuple(&.{ Felt252, bool }), + .{ + Felt252{ .fe = .{ + 0xffffffeb9bf00021, 0x9987fff, 0x87fffffffffb7c00, 0x3ffea55af00451, + } }, + true, + }, + ), + b.overflowing_shr(5), + ); + var c = Felt252.fromInteger(44444444); + try expectEqual( + @as( + std.meta.Tuple(&.{ Felt252, bool }), + .{ + Felt252{ .fe = .{ + 0xffffffffffeacea7, 0xffffffffffffffff, 0x243fffffffffffff, 0x1fffffe97b919, + } }, + true, + }, + ), + c.overflowing_shr(10), + ); + try expectEqual( + @as( + std.meta.Tuple(&.{ Felt252, bool }), + .{ + Felt252.zero(), + true, + }, + ), + c.overflowing_shr(5 * 64), + ); + var d = Felt252.fromInteger(33333333); + try expectEqual( + @as( + std.meta.Tuple(&.{ Felt252, bool }), + .{ + Felt252{ .fe = .{ + 0x7fffffbc72b4b70, + 0x0, + 0x0, + 0x0, + } }, + true, + }, + ), + d.overflowing_shr(3 * 64), + ); + var e = Felt252{ .fe = .{ + 0x0, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0x0, + } }; + try expectEqual( + @as( + std.meta.Tuple(&.{ Felt252, bool }), + .{ + Felt252{ .fe = .{ + 0x8000000000000000, 0xffffffffffffffff, 0x7fffffffffffffff, 0x0, + } }, + false, + }, + ), + e.overflowing_shr(1), + ); +} + +test "Felt252 checked_shr" { + var a = Felt252.fromInteger(10); + try expectEqual( + @as(?Felt252, null), + a.checked_shr(1), + ); + var b = Felt252.fromInteger(std.math.maxInt(u256)); + try expectEqual( + @as(?Felt252, null), + b.checked_shr(5), + ); + var c = Felt252.fromInteger(44444444); + try expectEqual( + @as(?Felt252, null), + c.checked_shr(10), + ); + try expectEqual( + @as(?Felt252, null), + c.checked_shr(5 * 64), + ); + var d = Felt252.fromInteger(33333333); + try expectEqual( + @as(?Felt252, null), + d.checked_shr(3 * 64), + ); + var e = Felt252{ .fe = .{ + 0x0, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0x0, + } }; + try expectEqual( + Felt252{ .fe = .{ + 0x8000000000000000, 0xffffffffffffffff, 0x7fffffffffffffff, 0x0, + } }, + e.checked_shr(1).?, + ); +} + +test "Felt252 wrapping_shr" { + var a = Felt252.fromInteger(10); + try expectEqual( + Felt252{ .fe = .{ + 0xffffffffffffff60, + 0xffffffffffffffff, + 0x7fffffffffffffff, + 0x3fffffffffff568, + } }, + a.wrapping_shr(1), + ); + var b = Felt252.fromInteger(std.math.maxInt(u256)); + try expectEqual( + Felt252{ .fe = .{ + 0xffffffeb9bf00021, 0x9987fff, 0x87fffffffffb7c00, 0x3ffea55af00451, + } }, + b.wrapping_shr(5), + ); + var c = Felt252.fromInteger(44444444); + try expectEqual( + Felt252{ .fe = .{ + 0xffffffffffeacea7, 0xffffffffffffffff, 0x243fffffffffffff, 0x1fffffe97b919, + } }, + c.wrapping_shr(10), + ); + try expectEqual( + Felt252.zero(), + c.wrapping_shr(5 * 64), + ); + var d = Felt252.fromInteger(33333333); + try expectEqual( + Felt252{ .fe = .{ + 0x7fffffbc72b4b70, + 0x0, + 0x0, + 0x0, + } }, + d.wrapping_shr(3 * 64), + ); + var e = Felt252{ .fe = .{ + 0x0, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0x0, + } }; + try expectEqual( + Felt252{ .fe = .{ + 0x8000000000000000, 0xffffffffffffffff, 0x7fffffffffffffff, 0x0, + } }, + e.wrapping_shr(1), + ); +}