-
Notifications
You must be signed in to change notification settings - Fork 96
/
elgamal.py
293 lines (230 loc) · 8.81 KB
/
elgamal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
from dataclasses import dataclass
from typing import Any, Iterable, Optional, Union
from .big_integer import bytes_to_hex
from .byte_padding import to_padded_bytes
from .discrete_log import DiscreteLog
from .group import (
ElementModQ,
ElementModP,
g_pow_p,
mult_p,
mult_inv_p,
pow_p,
ZERO_MOD_Q,
TWO_MOD_Q,
rand_range_q,
)
from .hash import hash_elems
from .hmac import get_hmac
from .logs import log_info, log_error
from .utils import get_optional
# An ElGamal secret key is an element mod Q (the exponent group).
ElGamalSecretKey = ElementModQ
# An ElGamal public key is an element mod P.
ElGamalPublicKey = ElementModP

# Number of bytes per chunk when splitting messages for hashed ElGamal
# (256 bits); NOTE(review): presumably matches the get_hmac output size — confirm.
_BLOCK_SIZE = 256 // 8
@dataclass
class ElGamalKeyPair:
    """
    Matching ElGamal keys: a secret key together with its public key.
    """

    # The secret exponent.
    secret_key: ElGamalSecretKey
    # The public key; elgamal_keypair_from_secret fills this with g^secret_key.
    public_key: ElGamalPublicKey
@dataclass
class ElGamalCiphertext:
    """
    An "exponential ElGamal ciphertext" (i.e., with the plaintext in the exponent to allow for
    homomorphic addition). Create one with `elgamal_encrypt`. Add them with `elgamal_add`.
    Decrypt using one of the supplied instance methods.
    """

    pad: ElementModP
    """pad or alpha"""

    data: ElementModP
    """encrypted data or beta"""

    def __eq__(self, other: Any) -> bool:
        """Two ciphertexts are equal when both their pads and their data are equal."""
        if not isinstance(other, ElGamalCiphertext):
            # Per the Python data model, defer to the other operand's __eq__
            # (falling back to identity, i.e. False) instead of answering for it.
            return NotImplemented
        return self.pad == other.pad and self.data == other.data

    def decrypt_known_product(self, product: ElementModP) -> int:
        """
        Decrypts an ElGamal ciphertext with a "known product" (the blinding factor used in the encryption).

        Computes data / product mod p and then recovers the exponent via a discrete log.

        :param product: The known product (blinding factor).
        :return: An exponentially encoded plaintext message.
        """
        # data * product^-1 mod p strips the blinding factor, leaving g^m.
        return DiscreteLog().discrete_log(mult_p(self.data, mult_inv_p(product)))

    def decrypt(self, secret_key: ElGamalSecretKey) -> int:
        """
        Decrypt an ElGamal ciphertext using a known ElGamal secret key.

        :param secret_key: The corresponding ElGamal secret key.
        :return: An exponentially encoded plaintext message.
        """
        # The blinding factor is pad^secret_key mod p.
        return self.decrypt_known_product(pow_p(self.pad, secret_key))

    def decrypt_known_nonce(
        self, public_key: ElGamalPublicKey, nonce: ElementModQ
    ) -> int:
        """
        Decrypt an ElGamal ciphertext using a known nonce and the ElGamal public key.

        :param public_key: The corresponding ElGamal public key.
        :param nonce: The secret nonce used to create the ciphertext.
        :return: An exponentially encoded plaintext message.
        """
        # The blinding factor is public_key^nonce mod p.
        return self.decrypt_known_product(pow_p(public_key, nonce))

    def partial_decrypt(self, secret_key: ElGamalSecretKey) -> ElementModP:
        """
        Partially decrypts an ElGamal ciphertext with a known ElGamal secret key.

        M_i = A^s_i mod p in the spec.

        :param secret_key: The corresponding ElGamal secret key (or key share).
        :return: The partial decryption share pad^secret_key mod p. This is not
            a plaintext; shares still have to be combined and removed from `data`.
        """
        return pow_p(self.pad, secret_key)

    def crypto_hash(self) -> ElementModQ:
        """
        Computes a cryptographic hash of this ciphertext (over pad, then data).
        """
        return hash_elems(self.pad, self.data)
@dataclass
class HashedElGamalCiphertext:
    """
    A hashed version of ElGamal Ciphertext with less size restrictions.
    Create one with `hashed_elgamal_encrypt` and decrypt it with `decrypt`.
    Unlike `ElGamalCiphertext`, these are not homomorphic, so `elgamal_add`
    does not apply to them.
    """

    pad: ElementModP
    """pad or alpha"""

    data: str
    """encrypted data or beta"""

    mac: str
    """message authentication code for hmac"""

    def decrypt(
        self, secret_key: ElGamalSecretKey, encryption_seed: ElementModQ
    ) -> Union[bytes, None]:
        """
        Decrypt an ElGamal ciphertext using a known ElGamal secret key.

        The MAC is verified before any data is decrypted.

        :param secret_key: The corresponding ElGamal secret key.
        :param encryption_seed: Encryption seed (Q) for election.
        :return: Decrypted plaintext message, or None if MAC verification fails.
            The result may carry trailing zero bytes from the block padding
            applied at encryption time.
        """
        # Constant-time comparison, so response timing does not reveal how much
        # of a forged MAC matched (secrets.compare_digest == hmac.compare_digest).
        from secrets import compare_digest

        session_key = hash_elems(self.pad, pow_p(self.pad, secret_key))
        # Hoisted: these inputs are identical for the MAC key and every block key.
        session_key_bytes = session_key.to_hex_bytes()
        seed_bytes = encryption_seed.to_hex_bytes()

        data_bytes = to_padded_bytes(self.data)
        (ciphertext_chunks, bit_length) = _get_chunks(data_bytes)

        mac_key = get_hmac(session_key_bytes, seed_bytes, bit_length)
        to_mac = self.pad.to_hex_bytes() + data_bytes
        mac = bytes_to_hex(get_hmac(mac_key, to_mac))
        if not compare_digest(mac.encode("utf-8"), self.mac.encode("utf-8")):
            log_error("MAC verification failed in decryption.")
            return None

        plaintext_blocks = []
        for i, block in enumerate(ciphertext_chunks):
            # Block keys are counter-indexed starting at 1, matching encryption.
            data_key = get_hmac(session_key_bytes, seed_bytes, bit_length, (i + 1))
            plaintext_blocks.append(bytes(a ^ b for (a, b) in zip(block, data_key)))
        return b"".join(plaintext_blocks)
def elgamal_keypair_from_secret(a: ElementModQ) -> Optional[ElGamalKeyPair]:
    """
    Build an ElGamal keypair from an existing secret key.

    :param a: The secret key; must not be below 2 (typically random in [2,Q)).
    :return: A keypair holding ``a`` and the public key g^a, or None when
        ``a`` is less than 2.
    """
    if a < 2:
        log_error("ElGamal secret key needs to be in [2,Q).")
        return None
    public_key = g_pow_p(a)
    return ElGamalKeyPair(secret_key=a, public_key=public_key)
def elgamal_keypair_random() -> ElGamalKeyPair:
    """
    Generate a new ElGamal keypair from a randomly chosen secret key.

    :return: A keypair whose secret key comes from ``rand_range_q(TWO_MOD_Q)``.
    """
    random_secret = rand_range_q(TWO_MOD_Q)
    keypair = elgamal_keypair_from_secret(random_secret)
    # The random secret is never below 2, so the keypair is expected to exist.
    return get_optional(keypair)
def elgamal_combine_public_keys(keys: Iterable[ElGamalPublicKey]) -> ElGamalPublicKey:
    """
    Merge several ElGamal public keys into a single joint public key.

    The joint key is the product of all given keys mod P.

    :param keys: The public keys to merge.
    :return: The joint public key.
    """
    all_keys = tuple(keys)
    return mult_p(*all_keys)
def elgamal_encrypt(
    message: int, nonce: ElementModQ, public_key: ElGamalPublicKey
) -> Optional[ElGamalCiphertext]:
    """
    Encrypt a fixed-length integer message with exponential ElGamal.

    Produces pad = g^nonce and data = g^message * public_key^nonce (all mod P).

    :param message: Known length message (m) to encrypt; must be an integer in [0,Q).
    :param nonce: Randomly chosen nonce in [1,Q).
    :param public_key: ElGamal public key.
    :return: An `ElGamalCiphertext`, or None if the nonce is zero.
    """
    if nonce == ZERO_MOD_Q:
        log_error("ElGamal encryption requires a non-zero nonce")
        return None

    pad = g_pow_p(nonce)
    # Arguments are evaluated left to right: g^m first, then K^r.
    data = mult_p(g_pow_p(message), pow_p(public_key, nonce))

    log_info(f": publicKey: {public_key.to_hex()}")
    log_info(f": pad: {pad.to_hex()}")
    log_info(f": data: {data.to_hex()}")
    return ElGamalCiphertext(pad, data)
def hashed_elgamal_encrypt(
    message: bytes,
    nonce: ElementModQ,
    public_key: ElGamalPublicKey,
    encryption_seed: ElementModQ,
) -> HashedElGamalCiphertext:
    """
    Encrypts a variable length byte message with a given random nonce and an ElGamal public key.

    The message is zero-padded to whole blocks. Each block is XORed with an
    HMAC-derived block key, and the result is authenticated with a MAC over
    the pad and the ciphertext.

    :param message: message (m) to encrypt; must be in bytes.
    :param nonce: Randomly chosen nonce in [1, Q).
    :param public_key: ElGamal public key.
    :param encryption_seed: Encryption seed (Q) for election.
    :return: A `HashedElGamalCiphertext` holding the pad, hex-encoded data and hex-encoded MAC.
    :raises ValueError: If ``nonce`` is zero.
    """
    if nonce == ZERO_MOD_Q:
        # With a zero nonce, pad = g^0 = 1 and K^0 = 1, so the session key is
        # hash(1, 1), which anyone can compute. Refuse instead of emitting it.
        log_error("Hashed ElGamal encryption requires a non-zero nonce")
        raise ValueError("Hashed ElGamal encryption requires a non-zero nonce")

    pad = g_pow_p(nonce)
    pubkey_pow_n = pow_p(public_key, nonce)
    session_key = hash_elems(pad, pubkey_pow_n)
    # Hoisted: these inputs are identical for every block key and the MAC key.
    session_key_bytes = session_key.to_hex_bytes()
    seed_bytes = encryption_seed.to_hex_bytes()

    (message_chunks, bit_length) = _get_chunks(message)
    encrypted_blocks = []
    for i, block in enumerate(message_chunks):
        # Block keys are counter-indexed starting at 1; decryption mirrors this.
        data_key = get_hmac(session_key_bytes, seed_bytes, bit_length, (i + 1))
        encrypted_blocks.append(bytes(a ^ b for (a, b) in zip(block, data_key)))
    data = b"".join(encrypted_blocks)

    mac_key = get_hmac(session_key_bytes, seed_bytes, bit_length)
    to_mac = pad.to_hex_bytes() + data
    mac = get_hmac(mac_key, to_mac)

    log_info(f": publicKey: {public_key.to_hex()}")
    log_info(f": pad: {pad.to_hex()}")
    log_info(f": data: {data!r}")
    log_info(f": mac: {bytes_to_hex(mac)}")
    log_info(f"to_mac {to_mac!r}")
    return HashedElGamalCiphertext(pad, bytes_to_hex(data), bytes_to_hex(mac))
def _get_chunks(message: bytes) -> tuple[list[bytes], int]:
    """
    Split a message into ``_BLOCK_SIZE``-byte blocks, zero-padding the last one.

    The caller's object is never modified. The original used ``+=``, which
    extended a ``bytearray`` argument in place.

    :param message: The bytes to split. An empty message yields no blocks.
    :return: A tuple of (list of blocks, bit length of the padded message).
    """
    remainder = len(message) % _BLOCK_SIZE
    if remainder:
        # Build a new object rather than extending the caller's buffer.
        message = message + bytes(_BLOCK_SIZE - remainder)
    number_of_blocks = len(message) // _BLOCK_SIZE
    blocks = [
        message[_BLOCK_SIZE * i : _BLOCK_SIZE * (i + 1)]
        for i in range(number_of_blocks)
    ]
    return (blocks, len(message) * 8)
def elgamal_add(*ciphertexts: ElGamalCiphertext) -> ElGamalCiphertext:
    """
    Homomorphically accumulates one or more ElGamal ciphertexts by pairwise multiplication. The exponents
    of vote counters will add.

    :param ciphertexts: One or more ciphertexts; pads and data are multiplied mod P.
    :return: The accumulated ciphertext. With a single argument, that same object is returned.
    :raises ValueError: If no ciphertexts are given.
    """
    if not ciphertexts:
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        raise ValueError("Must have one or more ciphertexts for elgamal_add")
    result = ciphertexts[0]
    for c in ciphertexts[1:]:
        result = ElGamalCiphertext(
            mult_p(result.pad, c.pad), mult_p(result.data, c.data)
        )
    return result