Skip to content

Commit

Permalink
Merge pull request #63 from GiacomoPope/improve_generic_polynomial_co…
Browse files Browse the repository at this point in the history
…verage
  • Loading branch information
GiacomoPope authored Jul 23, 2024
2 parents 5e34467 + b3133d4 commit ce9404f
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 8 deletions.
6 changes: 0 additions & 6 deletions src/kyber_py/modules/modules_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ def __add__(self, other):
False,
)

def __radd__(self, other):
return self.__add__(other)

def __iadd__(self, other):
self = self + other
return self
Expand All @@ -166,9 +163,6 @@ def __sub__(self, other):
False,
)

def __rsub__(self, other):
return self.__sub__(other)

def __isub__(self, other):
self = self - other
return self
Expand Down
2 changes: 1 addition & 1 deletion src/kyber_py/polynomials/polynomials_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def __sub__(self, other):
return self.parent(new_coeffs)

def __rsub__(self, other):
return self.__sub__(other)
return -self.__sub__(other)

def __isub__(self, other):
self = self - other
Expand Down
86 changes: 86 additions & 0 deletions tests/test_module_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,52 @@ def test_random_element(self):
self.assertEqual(type(A[0, 0]), self.R.element)
self.assertEqual(A.dim(), (m, n))

def test_print(self):
s = "Module over the commutative ring: Univariate Polynomial Ring in x over Finite Field of size 11 with modulus x^5 + 1"
self.assertEqual(str(self.M), s)
self.assertEqual(self.M.__repr__(), s)

def test_non_list_error(self):
self.assertRaises(TypeError, lambda: self.M("1"))

def test_non_ring_list_error(self):
one = self.R(1)
self.assertRaises(TypeError, lambda: self.M([one, "2", "3"]))
self.assertRaises(TypeError, lambda: self.M(["2", one, "3"]))
self.assertRaises(
TypeError, lambda: self.M([[one, "2", "3"], [one, "2", "3"]])
)
self.assertRaises(
TypeError, lambda: self.M([["1", one, "3"], [one, "2", "3"]])
)

def test_non_rectangular(self):
one = self.R(1)
self.assertRaises(ValueError, lambda: self.M([[one, one], [one]]))


class TestMatrix(unittest.TestCase):
R = PolynomialRing(11, 5)
R_prime = PolynomialRing(11, 2)
M = Module(R)
M_prime = Module(R)

def test_equality(self):
for _ in range(100):
A = self.M.random_element(2, 2)
B = self.M.random_element(2, 3)

self.assertEqual(A, A)
self.assertNotEqual(A, B)

def test_add_errors(self):
A = self.M.random_element(2, 2)
B = self.M.random_element(2, 3)
A_prime = self.M_prime.random_element(2, 2)

self.assertRaises(TypeError, lambda: A + "B")
self.assertRaises(ValueError, lambda: A + B)
self.assertRaises(TypeError, lambda: A + A_prime)

def test_matrix_add(self):
zero = self.R(0)
Expand All @@ -34,6 +76,19 @@ def test_matrix_add(self):
self.assertEqual(A + B, B + A)
self.assertEqual(A + (B + C), (A + B) + C)

B = C
B += C
self.assertEqual(B, C + C)

def test_sub_errors(self):
A = self.M.random_element(2, 2)
B = self.M.random_element(2, 3)
A_prime = self.M_prime.random_element(2, 2)

self.assertRaises(TypeError, lambda: A - "B")
self.assertRaises(ValueError, lambda: A - B)
self.assertRaises(TypeError, lambda: A - A_prime)

def test_matrix_sub(self):
zero = self.R(0)
Z = self.M([[zero, zero], [zero, zero]])
Expand All @@ -46,6 +101,19 @@ def test_matrix_sub(self):
self.assertEqual(A - B, -(B - A))
self.assertEqual(A - (B - C), (A - B) + C)

B = C
B -= C
self.assertEqual(B, Z)

def test_mul_errors(self):
A = self.M.random_element(2, 2)
B = self.M.random_element(5, 5)
A_prime = self.M_prime.random_element(2, 2)

self.assertRaises(TypeError, lambda: A @ "B")
self.assertRaises(ValueError, lambda: A @ B)
self.assertRaises(TypeError, lambda: A @ A_prime)

def test_matrix_mul_square(self):
zero = self.R(0)
one = self.R(1)
Expand Down Expand Up @@ -84,8 +152,13 @@ def test_matrix_transpose(self):
At = A.transpose()
AAt = A @ At

# Should always be symmetric
self.assertEqual(AAt, AAt.transpose())

# Assert transpose in place works
At.transpose_self()
self.assertEqual(A, At)

def test_matrix_dot(self):
for _ in range(100):
u = [self.R.random_element() for _ in range(5)]
Expand All @@ -96,6 +169,19 @@ def test_matrix_dot(self):
V = self.M.vector(v)

self.assertEqual(dot, U.dot(V))
self.assertRaises(TypeError, lambda: U.dot("A"))

def test_print(self):
A = self.M(
[self.R([1, 2]), self.R([3, 4, 5, 6])],
[self.R([0, 0, 0, 0, 3]), self.R([0, 1, 0, 3])],
)
u = self.M([self.R([1, 2]), self.R([3, 4, 5, 6])])

sA = "[ 1 + 2*x]\n[3 + 4*x + 5*x^2 + 6*x^3]"
su = "[1 + 2*x, 3 + 4*x + 5*x^2 + 6*x^3]"
self.assertEqual(str(A), sA)
self.assertEqual(str(u), su)


if __name__ == "__main__":
Expand Down
73 changes: 72 additions & 1 deletion tests/test_polynomial_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,27 @@ def test_random_element(self):
self.assertEqual(len(f.coeffs), self.R.n)
self.assertTrue(all([c < self.R.q for c in f.coeffs]))

def test_non_list_error(self):
self.assertRaises(TypeError, lambda: self.R("1"))

def test_long_list_error(self):
self.assertRaises(ValueError, lambda: self.R([0] * (self.R.n + 1)))

def test_string_format(self):
self.assertEqual(
str(self.R),
"Univariate Polynomial Ring in x over Finite Field of size 11 with modulus x^5 + 1",
)


class TestPolynomial(unittest.TestCase):
R = PolynomialRing(11, 5)

def test_getitem(self):
x = self.R.gen()
self.assertEqual(x[0], 0)
self.assertEqual(x[1], 1)

def test_is_zero(self):
self.assertTrue(self.R(0).is_zero())
self.assertFalse(self.R(1).is_zero())
Expand All @@ -36,7 +53,37 @@ def test_reduce_coefficients(self):
randint(-2 * self.R.q, 3 * self.R.q) for _ in range(self.R.n)
]
f = self.R(coeffs).reduce_coefficients()
self.assertTrue(all([c < self.R.q for c in f.coeffs]))
self.assertTrue(all([0 <= c < self.R.q for c in f.coeffs]))

def test_equality(self):
for _ in range(100):
f1 = self.R.random_element()
f2 = -f1
self.assertEqual(f1, f1)
if f1.is_zero():
self.assertTrue(f1 == f2)
else:
self.assertFalse(f1 == f2)

self.assertTrue(self.R(0) == 0)
self.assertTrue(self.R(1) == self.R.q + 1)
self.assertTrue(self.R(self.R.q - 1) == -1)

def test_add_failure(self):
f1 = self.R.random_element()
self.assertRaises(NotImplementedError, lambda: f1 + "a")

def test_sub_failure(self):
f1 = self.R.random_element()
self.assertRaises(NotImplementedError, lambda: f1 - "a")

def test_mul_failure(self):
f1 = self.R.random_element()
self.assertRaises(NotImplementedError, lambda: f1 * "a")

def test_pow_failure(self):
f1 = self.R.random_element()
self.assertRaises(TypeError, lambda: f1 ** "a")

def test_add_polynomials(self):
zero = self.R(0)
Expand All @@ -49,6 +96,10 @@ def test_add_polynomials(self):
self.assertEqual(f1 + f2, f2 + f1)
self.assertEqual(f1 + (f2 + f3), (f1 + f2) + f3)

f2 = f1
f2 += f1
self.assertEqual(f1 + f1, f2)

def test_sub_polynomials(self):
zero = self.R(0)
for _ in range(100):
Expand All @@ -58,9 +109,15 @@ def test_sub_polynomials(self):

self.assertEqual(f1 - zero, f1)
self.assertEqual(f3 - f3, zero)
self.assertEqual(f3 - 0, f3)
self.assertEqual(0 - f3, -f3)
self.assertEqual(f1 - f2, -(f2 - f1))
self.assertEqual(f1 - (f2 - f3), (f1 - f2) + f3)

f2 = f1
f2 -= f1
self.assertEqual(f2, zero)

def test_mul_polynomials(self):
zero = self.R(0)
one = self.R(1)
Expand All @@ -73,6 +130,12 @@ def test_mul_polynomials(self):
self.assertEqual(f1 * one, f1)
self.assertEqual(f1 * f2, f2 * f1)
self.assertEqual(f1 * (f2 * f3), (f1 * f2) * f3)
self.assertEqual(2 * f1, f1 + f1)
self.assertEqual(2 * f1, f1 * 2)

f2 = f1
f2 *= f2
self.assertEqual(f1 * f1, f2)

def test_pow_polynomials(self):
one = self.R(1)
Expand All @@ -85,6 +148,14 @@ def test_pow_polynomials(self):
self.assertEqual(f1 * f1 * f1, f1**3)
self.assertRaises(ValueError, lambda: f1 ** (-1))

def test_print(self):
self.assertEqual(str(self.R(0)), "0")
self.assertEqual(str(self.R(1)), "1")
self.assertEqual(str(self.R.gen()), "x")
self.assertEqual(
str(self.R([1, 2, 3, 4, 5])), "1 + 2*x + 3*x^2 + 4*x^3 + 5*x^4"
)


if __name__ == "__main__":
unittest.main()

0 comments on commit ce9404f

Please sign in to comment.