Skip to content

Commit

Permalink
initial support for ML-KEM hybrid key exchange groups
Browse files Browse the repository at this point in the history
  • Loading branch information
tomato42 committed Oct 7, 2024
1 parent 0156727 commit eef5315
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 16 deletions.
8 changes: 7 additions & 1 deletion tlslite/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,13 @@ class GroupName(TLSEnum):
brainpoolP512r1tls13 = 33
allEC.extend(list(range(31, 34)))

all = allEC + allFF
# draft-kwiatkowski-tls-ecdhe-mlkem
secp256r1mlkem768 = 0x11EB
x25519mlkem768 = 0x11EC
secp384r1mlkem1024 = 0x11ED
allKEM = [0x11EB, 0x11EC, 0x11ED]

all = allEC + allFF + allKEM

@classmethod
def toRepr(cls, value, blacklist=None):
Expand Down
17 changes: 13 additions & 4 deletions tlslite/handshakesettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .constants import CertificateType
from .utils import cryptomath
from .utils import cipherfactory
from .utils.compat import ecdsaAllCurves, int_types
from .utils.compat import ecdsaAllCurves, int_types, ML_KEM_AVAILABLE
from .utils.compression import compression_algo_impls

CIPHER_NAMES = ["chacha20-poly1305",
Expand All @@ -34,9 +34,13 @@
ALL_RSA_SIGNATURE_HASHES = RSA_SIGNATURE_HASHES + ["md5"]
SIGNATURE_SCHEMES = ["Ed25519", "Ed448"]
RSA_SCHEMES = ["pss", "pkcs1"]
CURVE_NAMES = []
if ML_KEM_AVAILABLE:
CURVE_NAMES += ["secp256r1mlkem768", "x25519mlkem768",
"secp384r1mlkem1024"]
# while secp521r1 is the most secure, it's also much slower than the others
# so place it as the last one
CURVE_NAMES = ["x25519", "x448", "secp384r1", "secp256r1",
CURVE_NAMES += ["x25519", "x448", "secp384r1", "secp256r1",
"secp521r1"]
ALL_CURVE_NAMES = CURVE_NAMES + ["secp256k1", "brainpoolP512r1",
"brainpoolP384r1", "brainpoolP256r1"]
Expand All @@ -57,7 +61,8 @@
TLS13_PERMITTED_GROUPS = ["secp256r1", "secp384r1", "secp521r1",
"x25519", "x448", "ffdhe2048",
"ffdhe3072", "ffdhe4096", "ffdhe6144",
"ffdhe8192"]
"ffdhe8192", "secp256r1mlkem768", "x25519mlkem768",
"secp384r1mlkem1024"]
KNOWN_VERSIONS = ((3, 0), (3, 1), (3, 2), (3, 3), (3, 4))
TICKET_CIPHERS = ["chacha20-poly1305", "aes256gcm", "aes128gcm", "aes128ccm",
"aes128ccm_8", "aes256ccm", "aes256ccm_8"]
Expand Down Expand Up @@ -395,7 +400,11 @@ def _init_key_settings(self):
self.dhParams = None
self.dhGroups = list(ALL_DH_GROUP_NAMES)
self.defaultCurve = "secp256r1"
self.keyShares = ["secp256r1", "x25519"]
if ML_KEM_AVAILABLE:
self.keyShares = ["x25519mlkem768"]
else:
self.keyShares = []
self.keyShares += ["secp256r1", "x25519"]
self.padding_cb = None
self.use_heartbeat_extension = True
self.heartbeat_response_callback = None
Expand Down
143 changes: 142 additions & 1 deletion tlslite/keyexchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
from .utils import tlshashlib as hashlib
from .utils.x25519 import x25519, x448, X25519_G, X448_G, X25519_ORDER_SIZE, \
X448_ORDER_SIZE
from .utils.compat import int_types
from .utils.compat import int_types, ML_KEM_AVAILABLE
from .utils.codec import DecodeError

if ML_KEM_AVAILABLE:
from kyber_py.ml_kem import ML_KEM_768, ML_KEM_1024


class KeyExchange(object):
"""
Common API for calculating Premaster secret
Expand Down Expand Up @@ -1062,3 +1066,140 @@ def calc_shared_key(self, private, peer_share):
S = ecdhYc * private

return numberToByteArray(S.x(), getPointByteSize(ecdhYc))


class KEMKeyExchange(object):
def __init__(self, group, version):
if not ML_KEM_AVAILABLE:
raise ImportError("kyber-py library not installed!")
self.group = group
assert version == (3, 4)
del version

if self.group not in GroupName.allKEM:
raise ValueError("called with wrong group")

if self.group == GroupName.secp256r1mlkem768:
self._classic_group = GroupName.secp256r1
elif self.group == GroupName.x25519mlkem768:
self._classic_group = GroupName.x25519
else:
assert self.group == GroupName.secp384r1mlkem1024
self._classic_group = GroupName.secp384r1

def get_random_private_key(self):
"""Generates a random value to be used as the private key in KEM."""

if self.group not in GroupName.allKEM:
raise ValueError("called with wrong group")
if self.group in (GroupName.secp256r1mlkem768,
GroupName.x25519mlkem768):
pqc_pub_key, pqc_priv_key = ML_KEM_768.keygen()
else:
pqc_pub_key, pqc_priv_key = ML_KEM_1024.keygen()

classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
classic_key = classic_kex.get_random_private_key()

return ((pqc_pub_key, pqc_priv_key), classic_key)

def calc_public_value(self, private):
classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))

classic_pub_key_share = classic_kex.calc_public_value(private[1])

if self.group == GroupName.x25519mlkem768:
return private[0][0] + classic_pub_key_share
return classic_pub_key_share + private[0][0]

def encapsulate_key(self, public):
if self.group == GroupName.secp256r1mlkem768:
classic_key_len = 65
pqc_key_len = 1184
pqc_first = False
ml_kem = ML_KEM_768
elif self.group == GroupName.x25519mlkem768:
classic_key_len = 32
pqc_key_len = 1184
pqc_first = True
ml_kem = ML_KEM_768
else:
assert self.group == GroupName.secp384r1mlkem1024
classic_key_len = 97
pqc_key_len = 1568
pqc_first = False
ml_kem = ML_KEM_1024

if len(public) != classic_key_len + pqc_key_len:
raise ValueError("Invalid key size for the selected group")

if pqc_first:
pqc_key = public[:pqc_key_len]
classic_key_share = public[pqc_key_len:]
else:
classic_key_share = public[:classic_key_len]
pqc_key = public[classic_key_len:]

classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
classic_key = classic_kex.get_random_private_key()
classic_my_key_share = classic_kex.calc_public_value(classic_key)
print("classic key: {0}".format(classic_key_share))
classic_shared_secret = classic_kex.calc_shared_key(
classic_key, classic_key_share)

pqc_shared_secret, pqc_encaps = ml_kem.encaps(pqc_key)

if pqc_first:
shared_secret = pqc_shared_secret + classic_shared_secret
key_encapsulation = pqc_encaps + classic_my_key_share
else:
shared_secret = classic_shared_secret + pqc_shared_secret
key_encapsulation = classic_my_key_share + pqc_encaps

return shared_secret, key_encapsulation

def calc_shared_key(self, private, key_encaps):
if self.group == GroupName.secp256r1mlkem768:
classic_key_len = 65
pqc_key_len = 1088
pqc_first = False
ml_kem = ML_KEM_768
elif self.group == GroupName.x25519mlkem768:
classic_key_len = 32
pqc_key_len = 1088
pqc_first = True
ml_kem = ML_KEM_768
else:
assert self.group == GroupName.secp384r1mlkem1024
classic_key_len = 97
pqc_key_len = 1568
pqc_first = False
ml_kem = ML_KEM_1024

if len(key_encaps) != classic_key_len + pqc_key_len:
raise ValueError("Invalid key size for the selected group. "
"Expected {0}, received {1}".format(
classic_key_len + pqc_key_len,
len(key_encaps)))

if pqc_first:
pqc_key = key_encaps[:pqc_key_len]
classic_key_share = key_encaps[pqc_key_len:]
else:
classic_key_share = key_encaps[:classic_key_len]
pqc_key = key_encaps[classic_key_len:]

classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
print("classic key: {0}".format(classic_key_share))
classic_shared_secret = classic_kex.calc_shared_key(
private[1], classic_key_share)

pqc_shared_secret = ml_kem.decaps(pqc_key, private[0][1])

if pqc_first:
shared_secret = pqc_shared_secret + classic_shared_secret
else:
shared_secret = classic_shared_secret + pqc_shared_secret

return shared_secret

36 changes: 26 additions & 10 deletions tlslite/tlsconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .utils.deprecations import deprecated_params
from .keyexchange import KeyExchange, RSAKeyExchange, DHE_RSAKeyExchange, \
ECDHE_RSAKeyExchange, SRPKeyExchange, ADHKeyExchange, \
AECDHKeyExchange, FFDHKeyExchange, ECDHKeyExchange
AECDHKeyExchange, FFDHKeyExchange, ECDHKeyExchange, KEMKeyExchange
from .handshakehelpers import HandshakeHelpers
from .utils.cipherfactory import createAESCCM, createAESCCM_8, \
createAESGCM, createCHACHA20
Expand Down Expand Up @@ -1196,6 +1196,8 @@ def _clientGetServerHello(self, settings, session, clientHello):
@staticmethod
def _getKEX(group, version):
"""Get object for performing key exchange."""
if group in GroupName.allKEM:
return KEMKeyExchange(group, version)
if group in GroupName.allFF:
return FFDHKeyExchange(group, version)
return ECDHKeyExchange(group, version)
Expand All @@ -1209,6 +1211,15 @@ def _genKeyShareEntry(cls, group, version):
share = kex.calc_public_value(private)
return KeyShareEntry().create(group, share, private)

@classmethod
def _KEMEncaps(cls, group, public):
"""Generate the server's KeyShareEntry object with encapsulated secret.
"""
kex = cls._getKEX(group, (3, 4))
shared_sec, key_share_value = kex.encapsulate_key(public)
key_share = KeyShareEntry().create(group, key_share_value, None)
return shared_sec, key_share

@staticmethod
def _getPRFParams(cipher_suite):
"""Return name of hash used for PRF and the hash output size."""
Expand Down Expand Up @@ -2803,16 +2814,21 @@ def _serverTLS13Handshake(self, settings, clientHello, cipherSuite,
(psk is None and privateKey):
self.ecdhCurve = selected_group
kex = self._getKEX(selected_group, version)
key_share = self._genKeyShareEntry(selected_group, version)
if selected_group in GroupName.allKEM:
shared_sec, key_share = self._KEMEncaps(
selected_group,
cl_key_share.key_exchange)
else:
key_share = self._genKeyShareEntry(selected_group, version)

try:
shared_sec = kex.calc_shared_key(key_share.private,
cl_key_share.key_exchange)
except TLSIllegalParameterException as alert:
for result in self._sendError(
AlertDescription.illegal_parameter,
str(alert)):
yield result
try:
shared_sec = kex.calc_shared_key(key_share.private,
cl_key_share.key_exchange)
except TLSIllegalParameterException as alert:
for result in self._sendError(
AlertDescription.illegal_parameter,
str(alert)):
yield result

sh_extensions.append(ServerKeyShareExtension().create(key_share))
elif (psk is not None and
Expand Down
11 changes: 11 additions & 0 deletions tlslite/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,14 @@ def byte_length(val):
ecdsaAllCurves = False
else:
ecdsaAllCurves = True


# kyber-py is an optional dependency
try:
from kyber_py.ml_kem import ML_KEM_768, ML_KEM_1024
del ML_KEM_768
del ML_KEM_1024
except ImportError:
ML_KEM_AVAILABLE = False
else:
ML_KEM_AVAILABLE = True

0 comments on commit eef5315

Please sign in to comment.