diff --git a/.spec/math/modular_inverse_spec.lua b/.spec/math/modular_inverse_spec.lua index 08eba24..3222185 100644 --- a/.spec/math/modular_inverse_spec.lua +++ b/.spec/math/modular_inverse_spec.lua @@ -13,7 +13,13 @@ describe("Modular inverse", function() it("should handle cases when inputs are not co-prime", function() assert.equal(nil, modular_inverse(2, 2)) assert.equal(nil, modular_inverse(5, 15)) + end) + + it("should handle cases when modulus is 1", function() + assert.equal(nil, modular_inverse(-1, 1)) assert.equal(nil, modular_inverse(0, 1)) + assert.equal(nil, modular_inverse(1, 1)) + assert.equal(nil, modular_inverse(2, 1)) end) it("should throw error when modulus is zero", function() diff --git a/src/math/modular_inverse.lua b/src/math/modular_inverse.lua index 00cf806..5d81055 100644 --- a/src/math/modular_inverse.lua +++ b/src/math/modular_inverse.lua @@ -8,8 +8,11 @@ return function( m -- modulus ) assert(m > 0, "modulus must be positive") + if m == 1 then + return nil + end local gcd, x, _ = extended_gcd(a % m, m) - if a ~= 0 and gcd == 1 then + if gcd == 1 then -- Ensure that result is in (0, m) return x % m end