From 4324b40570bffc20d695d4aaeb5fd5fd5c6b9670 Mon Sep 17 00:00:00 2001 From: Si Beaumont Date: Thu, 17 Oct 2024 13:23:48 +0100 Subject: [PATCH] wrapper: Add modular arithmetic functions to ArbitraryPrecisionInteger (#284) --- .../Util/ArbitraryPrecisionInteger.swift | 113 +++++++++++++++++ .../ArbitraryPrecisionIntegerTests.swift | 118 ++++++++++++++++++ 2 files changed, 231 insertions(+) diff --git a/Sources/CryptoBoringWrapper/Util/ArbitraryPrecisionInteger.swift b/Sources/CryptoBoringWrapper/Util/ArbitraryPrecisionInteger.swift index 6b812816..b87efca8 100644 --- a/Sources/CryptoBoringWrapper/Util/ArbitraryPrecisionInteger.swift +++ b/Sources/CryptoBoringWrapper/Util/ArbitraryPrecisionInteger.swift @@ -386,6 +386,119 @@ extension ArbitraryPrecisionInteger: Numeric { } } +// MARK: - Modular arithmetic + +extension ArbitraryPrecisionInteger { + @usableFromInline + package func modulo(_ mod: ArbitraryPrecisionInteger, nonNegative: Bool = false) throws -> ArbitraryPrecisionInteger { + var result = ArbitraryPrecisionInteger() + + let rc = result.withUnsafeMutableBignumPointer { resultPtr in + self.withUnsafeBignumPointer { selfPtr in + mod.withUnsafeBignumPointer { modPtr in + ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in + if nonNegative { + CCryptoBoringSSL_BN_nnmod(resultPtr, selfPtr, modPtr, bnCtx) + } else { + CCryptoBoringSSLShims_BN_mod(resultPtr, selfPtr, modPtr, bnCtx) + } + } + } + } + } + guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() } + + return result + } + + @usableFromInline + package func inverse(modulo mod: ArbitraryPrecisionInteger) throws -> ArbitraryPrecisionInteger { + var result = ArbitraryPrecisionInteger() + + let rc = result.withUnsafeMutableBignumPointer { resultPtr in + self.withUnsafeBignumPointer { selfPtr in + mod.withUnsafeBignumPointer { modPtr in + ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in + CCryptoBoringSSL_BN_mod_inverse(resultPtr, selfPtr, modPtr, bnCtx) + } + } + } + } + guard rc != nil else { throw CryptoBoringWrapperError.internalBoringSSLError() } + + return result + } + + + @usableFromInline + package static func inverse(lhs: ArbitraryPrecisionInteger, modulo mod: ArbitraryPrecisionInteger) throws -> ArbitraryPrecisionInteger { + try ArbitraryPrecisionInteger(lhs).inverse(modulo: mod) + } + + @usableFromInline + package func add(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger { + guard let modulus else { return self + rhs } + var result = ArbitraryPrecisionInteger() + + let rc = result.withUnsafeMutableBignumPointer { resultPtr in + self.withUnsafeBignumPointer { selfPtr in + rhs.withUnsafeBignumPointer { rhsPtr in + modulus.withUnsafeBignumPointer { modulusPtr in + ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in + return CCryptoBoringSSL_BN_mod_add(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx) + } + } + } + } + } + guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() } + + return result + } + + @usableFromInline + package func sub(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger { + guard let modulus else { return self - rhs } + var result = ArbitraryPrecisionInteger() + + let rc = result.withUnsafeMutableBignumPointer { resultPtr in + self.withUnsafeBignumPointer { selfPtr in + rhs.withUnsafeBignumPointer { rhsPtr in + modulus.withUnsafeBignumPointer { modulusPtr in + ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in + CCryptoBoringSSL_BN_mod_sub(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx) + } + } + } + } + } + guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() } + + return result + } + + @usableFromInline + package func mul(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger { + guard let modulus else { return self * rhs } + var result = ArbitraryPrecisionInteger() + + let rc = result.withUnsafeMutableBignumPointer { resultPtr in + self.withUnsafeBignumPointer { selfPtr in + rhs.withUnsafeBignumPointer { rhsPtr in + modulus.withUnsafeBignumPointer { modulusPtr in + ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in + return CCryptoBoringSSL_BN_mod_mul(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx) + } + } + } + } + } + guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() } + + return result + } +} + // MARK: - SignedNumeric extension ArbitraryPrecisionInteger: SignedNumeric { diff --git a/Tests/CryptoBoringWrapperTests/ArbitraryPrecisionIntegerTests.swift b/Tests/CryptoBoringWrapperTests/ArbitraryPrecisionIntegerTests.swift index e96c7bcc..2957e8d3 100644 --- a/Tests/CryptoBoringWrapperTests/ArbitraryPrecisionIntegerTests.swift +++ b/Tests/CryptoBoringWrapperTests/ArbitraryPrecisionIntegerTests.swift @@ -165,4 +165,122 @@ final class ArbitraryPrecisionIntegerTests: XCTestCase { XCTAssertEqual(try ArbitraryPrecisionInteger(bytes: bytes), integer) } } + + func testMoudlo() throws { + typealias I = ArbitraryPrecisionInteger + typealias Vector = (input: I, mod: I, expectedResult: (standard: I, nonNegative: I)) + for vector: Vector in [ + (input: 0, mod: 2, expectedResult: (standard: 0, nonNegative: 0)), + (input: 1, mod: 2, expectedResult: (standard: 1, nonNegative: 1)), + (input: 2, mod: 2, expectedResult: (standard: 0, nonNegative: 0)), + (input: 3, mod: 2, expectedResult: (standard: 1, nonNegative: 1)), + (input: 4, mod: 2, expectedResult: (standard: 0, nonNegative: 0)), + (input: 5, mod: 2, expectedResult: (standard: 1, nonNegative: 1)), + (input: 7, mod: 5, expectedResult: (standard: 2, nonNegative: 2)), + (input: 7, mod: -5, expectedResult: (standard: 2, nonNegative: 2)), + (input: -7, mod: 5, expectedResult: (standard: -2, nonNegative: 3)), + (input: -7, mod: -5, expectedResult: (standard: -2, nonNegative: 3)), + ] { + XCTAssertEqual( + try vector.input.modulo(vector.mod, nonNegative: false), + vector.expectedResult.standard, + "\(vector.input) (mod \(vector.mod))" + ) + XCTAssertEqual( + try vector.input.modulo(vector.mod, nonNegative: true), + vector.expectedResult.nonNegative, + "\(vector.input) (nnmod \(vector.mod))" + ) + } + } + + func testModularInverse() throws { + typealias I = ArbitraryPrecisionInteger + enum O { case ok(I), throwsError } + typealias Vector = (a: I, mod: I, expectedResult: O) + for vector: Vector in [ + (a: 3, mod: 7, expectedResult: .ok(5)), + (a: 10, mod: 17, expectedResult: .ok(12)), + (a: 7, mod: 26, expectedResult: .ok(15)), + (a: 7, mod: 7, expectedResult: .throwsError), + ] { + switch vector.expectedResult { + case .ok(let expectedValue): + XCTAssertEqual(try vector.a.inverse(modulo: vector.mod), expectedValue, "inverse(\(vector.a), modulo: \(vector.mod))") + case .throwsError: + XCTAssertThrowsError(try vector.a.inverse(modulo: vector.mod), "inverse(\(vector.a), modulo: \(vector.mod)") + } + } + } + + func testModularAddition() throws { + typealias I = ArbitraryPrecisionInteger + enum O { case ok(I), throwsError } + typealias Vector = (a: I, b: I, mod: I, expectedResult: O) + for vector: Vector in [ + (a: 0, b: 0, mod: 0, expectedResult: .throwsError), + (a: 0, b: 0, mod: 2, expectedResult: .ok(0)), + (a: 1, b: 0, mod: 2, expectedResult: .ok(1)), + (a: 0, b: 1, mod: 2, expectedResult: .ok(1)), + (a: 1, b: 1, mod: 2, expectedResult: .ok(0)), + (a: 4, b: 3, mod: 5, expectedResult: .ok(2)), + (a: 4, b: 3, mod: -5, expectedResult: .ok(2)), + (a: -4, b: -3, mod: 5, expectedResult: .ok(3)), + ] { + switch vector.expectedResult { + case .ok(let expectedValue): + XCTAssertEqual(try vector.a.add(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) + \(vector.b) (mod \(vector.mod))") + case .throwsError: + XCTAssertThrowsError(try vector.a.add(vector.b, modulo: vector.mod), "\(vector.a) + \(vector.b) (mod \(vector.mod))") + } + } + } + + func testModularSubtraction() throws { + typealias I = ArbitraryPrecisionInteger + enum O { case ok(I), throwsError } + typealias Vector = (a: I, b: I, mod: I, expectedResult: O) + for vector: Vector in [ + (a: 0, b: 0, mod: 0, expectedResult: .throwsError), + (a: 0, b: 0, mod: 2, expectedResult: .ok(0)), + (a: 1, b: 0, mod: 2, expectedResult: .ok(1)), + (a: 0, b: 1, mod: 2, expectedResult: .ok(1)), + (a: 1, b: 1, mod: 2, expectedResult: .ok(0)), + (a: 4, b: 3, mod: 5, expectedResult: .ok(1)), + (a: 3, b: 4, mod: 5, expectedResult: .ok(4)), + (a: 3, b: 4, mod: -5, expectedResult: .ok(4)), + (a: -3, b: 4, mod: 5, expectedResult: .ok(3)), + (a: 3, b: -4, mod: 5, expectedResult: .ok(2)), + ] { + switch vector.expectedResult { + case .ok(let expectedValue): + XCTAssertEqual(try vector.a.sub(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) - \(vector.b) (mod \(vector.mod))") + case .throwsError: + XCTAssertThrowsError(try vector.a.sub(vector.b, modulo: vector.mod), "\(vector.a) - \(vector.b) (mod \(vector.mod))") + } + } + } + + func testModularMultiplication() throws { + typealias I = ArbitraryPrecisionInteger + enum O { case ok(I), throwsError } + typealias Vector = (a: I, b: I, mod: I, expectedResult: O) + for vector: Vector in [ + (a: 0, b: 0, mod: 0, expectedResult: .throwsError), + (a: 0, b: 0, mod: 2, expectedResult: .ok(0)), + (a: 1, b: 0, mod: 2, expectedResult: .ok(0)), + (a: 0, b: 1, mod: 2, expectedResult: .ok(0)), + (a: 1, b: 1, mod: 2, expectedResult: .ok(1)), + (a: 4, b: 3, mod: 5, expectedResult: .ok(2)), + (a: 4, b: 3, mod: -5, expectedResult: .ok(2)), + (a: -4, b: -3, mod: 5, expectedResult: .ok(2)), + ] { + switch vector.expectedResult { + case .ok(let expectedValue): + XCTAssertEqual(try vector.a.mul(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) × \(vector.b) (mod \(vector.mod))") + case .throwsError: + XCTAssertThrowsError(try vector.a.mul(vector.b, modulo: vector.mod), "\(vector.a) × \(vector.b) (mod \(vector.mod))") + } + } + } }