Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wrapper: Add modular arithmetic functions to ArbitraryPrecisionInteger #284

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions Sources/CryptoBoringWrapper/Util/ArbitraryPrecisionInteger.swift
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,119 @@ extension ArbitraryPrecisionInteger: Numeric {
}
}

// MARK: - Modular arithmetic

extension ArbitraryPrecisionInteger {
@usableFromInline
package func modulo(_ mod: ArbitraryPrecisionInteger, nonNegative: Bool = false) throws -> ArbitraryPrecisionInteger {
var result = ArbitraryPrecisionInteger()

let rc = result.withUnsafeMutableBignumPointer { resultPtr in
self.withUnsafeBignumPointer { selfPtr in
mod.withUnsafeBignumPointer { modPtr in
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
if nonNegative {
CCryptoBoringSSL_BN_nnmod(resultPtr, selfPtr, modPtr, bnCtx)
} else {
CCryptoBoringSSLShims_BN_mod(resultPtr, selfPtr, modPtr, bnCtx)
}
}
}
}
}
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }

return result
}

@usableFromInline
package func inverse(modulo mod: ArbitraryPrecisionInteger) throws -> ArbitraryPrecisionInteger {
var result = ArbitraryPrecisionInteger()

let rc = result.withUnsafeMutableBignumPointer { resultPtr in
self.withUnsafeBignumPointer { selfPtr in
mod.withUnsafeBignumPointer { modPtr in
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
CCryptoBoringSSL_BN_mod_inverse(resultPtr, selfPtr, modPtr, bnCtx)
}
}
}
}
guard rc != nil else { throw CryptoBoringWrapperError.internalBoringSSLError() }

return result
}


@usableFromInline
package static func inverse(lhs: ArbitraryPrecisionInteger, modulo mod: ArbitraryPrecisionInteger) throws -> ArbitraryPrecisionInteger {
try ArbitraryPrecisionInteger(lhs).inverse(modulo: mod)
}

@usableFromInline
package func add(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger {
guard let modulus else { return self + rhs }
var result = ArbitraryPrecisionInteger()

let rc = result.withUnsafeMutableBignumPointer { resultPtr in
self.withUnsafeBignumPointer { selfPtr in
rhs.withUnsafeBignumPointer { rhsPtr in
modulus.withUnsafeBignumPointer { modulusPtr in
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
return CCryptoBoringSSL_BN_mod_add(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx)
}
}
}
}
}
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }

return result
}

@usableFromInline
package func sub(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger {
guard let modulus else { return self - rhs }
var result = ArbitraryPrecisionInteger()

let rc = result.withUnsafeMutableBignumPointer { resultPtr in
self.withUnsafeBignumPointer { selfPtr in
rhs.withUnsafeBignumPointer { rhsPtr in
modulus.withUnsafeBignumPointer { modulusPtr in
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
CCryptoBoringSSL_BN_mod_sub(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx)
}
}
}
}
}
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }

return result
}

@usableFromInline
package func mul(_ rhs: ArbitraryPrecisionInteger, modulo modulus: ArbitraryPrecisionInteger? = nil) throws -> ArbitraryPrecisionInteger {
guard let modulus else { return self * rhs }
var result = ArbitraryPrecisionInteger()

let rc = result.withUnsafeMutableBignumPointer { resultPtr in
self.withUnsafeBignumPointer { selfPtr in
rhs.withUnsafeBignumPointer { rhsPtr in
modulus.withUnsafeBignumPointer { modulusPtr in
ArbitraryPrecisionInteger.withUnsafeBN_CTX { bnCtx in
return CCryptoBoringSSL_BN_mod_mul(resultPtr, selfPtr, rhsPtr, modulusPtr, bnCtx)
}
}
}
}
}
guard rc == 1 else { throw CryptoBoringWrapperError.internalBoringSSLError() }

return result
}
}

// MARK: - SignedNumeric

extension ArbitraryPrecisionInteger: SignedNumeric {
Expand Down
118 changes: 118 additions & 0 deletions Tests/CryptoBoringWrapperTests/ArbitraryPrecisionIntegerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,122 @@ final class ArbitraryPrecisionIntegerTests: XCTestCase {
XCTAssertEqual(try ArbitraryPrecisionInteger(bytes: bytes), integer)
}
}

func testMoudlo() throws {
typealias I = ArbitraryPrecisionInteger
typealias Vector = (input: I, mod: I, expectedResult: (standard: I, nonNegative: I))
for vector: Vector in [
(input: 0, mod: 2, expectedResult: (standard: 0, nonNegative: 0)),
(input: 1, mod: 2, expectedResult: (standard: 1, nonNegative: 1)),
(input: 2, mod: 2, expectedResult: (standard: 0, nonNegative: 0)),
(input: 3, mod: 2, expectedResult: (standard: 1, nonNegative: 1)),
(input: 4, mod: 2, expectedResult: (standard: 0, nonNegative: 0)),
(input: 5, mod: 2, expectedResult: (standard: 1, nonNegative: 1)),
(input: 7, mod: 5, expectedResult: (standard: 2, nonNegative: 2)),
(input: 7, mod: -5, expectedResult: (standard: 2, nonNegative: 2)),
(input: -7, mod: 5, expectedResult: (standard: -2, nonNegative: 3)),
(input: -7, mod: -5, expectedResult: (standard: -2, nonNegative: 3)),
] {
XCTAssertEqual(
try vector.input.modulo(vector.mod, nonNegative: false),
vector.expectedResult.standard,
"\(vector.input) (mod \(vector.mod))"
)
XCTAssertEqual(
try vector.input.modulo(vector.mod, nonNegative: true),
vector.expectedResult.nonNegative,
"\(vector.input) (nnmod \(vector.mod))"
)
}
}

func testModularInverse() throws {
typealias I = ArbitraryPrecisionInteger
enum O { case ok(I), throwsError }
typealias Vector = (a: I, mod: I, expectedResult: O)
for vector: Vector in [
(a: 3, mod: 7, expectedResult: .ok(5)),
(a: 10, mod: 17, expectedResult: .ok(12)),
(a: 7, mod: 26, expectedResult: .ok(15)),
(a: 7, mod: 7, expectedResult: .throwsError),
] {
switch vector.expectedResult {
case .ok(let expectedValue):
XCTAssertEqual(try vector.a.inverse(modulo: vector.mod), expectedValue, "inverse(\(vector.a), modulo: \(vector.mod))")
case .throwsError:
XCTAssertThrowsError(try vector.a.inverse(modulo: vector.mod), "inverse(\(vector.a), modulo: \(vector.mod)")
}
}
}

func testModularAddition() throws {
typealias I = ArbitraryPrecisionInteger
enum O { case ok(I), throwsError }
typealias Vector = (a: I, b: I, mod: I, expectedResult: O)
for vector: Vector in [
(a: 0, b: 0, mod: 0, expectedResult: .throwsError),
(a: 0, b: 0, mod: 2, expectedResult: .ok(0)),
(a: 1, b: 0, mod: 2, expectedResult: .ok(1)),
(a: 0, b: 1, mod: 2, expectedResult: .ok(1)),
(a: 1, b: 1, mod: 2, expectedResult: .ok(0)),
(a: 4, b: 3, mod: 5, expectedResult: .ok(2)),
(a: 4, b: 3, mod: -5, expectedResult: .ok(2)),
(a: -4, b: -3, mod: 5, expectedResult: .ok(3)),
] {
switch vector.expectedResult {
case .ok(let expectedValue):
XCTAssertEqual(try vector.a.add(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) + \(vector.b) (mod \(vector.mod))")
case .throwsError:
XCTAssertThrowsError(try vector.a.add(vector.b, modulo: vector.mod), "\(vector.a) + \(vector.b) (mod \(vector.mod))")
}
}
}

func testModularSubtraction() throws {
typealias I = ArbitraryPrecisionInteger
enum O { case ok(I), throwsError }
typealias Vector = (a: I, b: I, mod: I, expectedResult: O)
for vector: Vector in [
(a: 0, b: 0, mod: 0, expectedResult: .throwsError),
(a: 0, b: 0, mod: 2, expectedResult: .ok(0)),
(a: 1, b: 0, mod: 2, expectedResult: .ok(1)),
(a: 0, b: 1, mod: 2, expectedResult: .ok(1)),
(a: 1, b: 1, mod: 2, expectedResult: .ok(0)),
(a: 4, b: 3, mod: 5, expectedResult: .ok(1)),
(a: 3, b: 4, mod: 5, expectedResult: .ok(4)),
(a: 3, b: 4, mod: -5, expectedResult: .ok(4)),
(a: -3, b: 4, mod: 5, expectedResult: .ok(3)),
(a: 3, b: -4, mod: 5, expectedResult: .ok(2)),
] {
switch vector.expectedResult {
case .ok(let expectedValue):
XCTAssertEqual(try vector.a.sub(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) - \(vector.b) (mod \(vector.mod))")
case .throwsError:
XCTAssertThrowsError(try vector.a.sub(vector.b, modulo: vector.mod), "\(vector.a) - \(vector.b) (mod \(vector.mod))")
}
}
}

func testModularMultiplication() throws {
typealias I = ArbitraryPrecisionInteger
enum O { case ok(I), throwsError }
typealias Vector = (a: I, b: I, mod: I, expectedResult: O)
for vector: Vector in [
(a: 0, b: 0, mod: 0, expectedResult: .throwsError),
(a: 0, b: 0, mod: 2, expectedResult: .ok(0)),
(a: 1, b: 0, mod: 2, expectedResult: .ok(0)),
(a: 0, b: 1, mod: 2, expectedResult: .ok(0)),
(a: 1, b: 1, mod: 2, expectedResult: .ok(1)),
(a: 4, b: 3, mod: 5, expectedResult: .ok(2)),
(a: 4, b: 3, mod: -5, expectedResult: .ok(2)),
(a: -4, b: -3, mod: 5, expectedResult: .ok(2)),
] {
switch vector.expectedResult {
case .ok(let expectedValue):
XCTAssertEqual(try vector.a.mul(vector.b, modulo: vector.mod), expectedValue, "\(vector.a) × \(vector.b) (mod \(vector.mod))")
case .throwsError:
XCTAssertThrowsError(try vector.a.mul(vector.b, modulo: vector.mod), "\(vector.a) × \(vector.b) (mod \(vector.mod))")
}
}
}
}
Loading