Skip to content

Commit

Permalink
Merge pull request #512 from gstarovo/ecc_changes
Browse files Browse the repository at this point in the history
Usage of newer ecsda package
  • Loading branch information
tomato42 authored Mar 13, 2024
2 parents 703a0c3 + 5590419 commit 6db0826
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 199 deletions.
13 changes: 8 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ name: GitHub CI
on:
push:
branches:
- master
- tlslite-ng-0.7
- master
- tlslite-ng-0.7
pull_request:

jobs:
Expand Down Expand Up @@ -275,9 +275,12 @@ jobs:
wget https://files.pythonhosted.org/packages/3b/7e/293d19ccd106119e35db4bf3e111b1895098f618b455b758aa636496cf03/setuptools-28.8.0-py2.py3-none-any.whl
wget https://files.pythonhosted.org/packages/83/53/e120833aa2350db333df89a40dea3b310dd9dabf6f29eaa18934a597dc79/wheel-0.30.0a0-py2.py3-none-any.whl
pip install setuptools-28.8.0-py2.py3-none-any.whl wheel-0.30.0a0-py2.py3-none-any.whl
- name: Install M2Crypto
if: ${{ contains(matrix.opt-deps, 'm2crypto') }}
run: pip install --pre m2crypto
- name: Install M2Crypto for python 2.7
if: ${{ contains(matrix.opt-deps, 'm2crypto') && matrix.python-version == '2.7' }}
run: pip install M2Crypto==0.37.1
- name: Install M2Crypto for python
if: ${{ contains(matrix.opt-deps, 'm2crypto') && matrix.python-version != '2.7' }}
run: pip install --pre M2Crypto
- name: Install tackpy
if: ${{ contains(matrix.opt-deps, 'tackpy') }}
run: pip install tackpy
Expand Down
Empty file added test
Empty file.
59 changes: 42 additions & 17 deletions tlslite/keyexchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from .messages import ServerKeyExchange, ClientKeyExchange, CertificateVerify
from .constants import SignatureAlgorithm, HashAlgorithm, CipherSuite, \
ExtensionType, GroupName, ECCurveType, SignatureScheme
from .utils.ecc import decodeX962Point, encodeX962Point, getCurveByName, \
getPointByteSize
from .utils.ecc import getCurveByName, getPointByteSize
from .utils.rsakey import RSAKey
from .utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \
numBits, numberToByteArray, divceil, numBytes, secureHash
Expand All @@ -25,7 +24,6 @@
from .utils.compat import int_types
from .utils.codec import DecodeError


class KeyExchange(object):
"""
Common API for calculating Premaster secret
Expand Down Expand Up @@ -706,7 +704,15 @@ def makeServerKeyExchange(self, sigHash=None):

kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version)
self.ecdhXs = kex.get_random_private_key()
ecdhYs = kex.calc_public_value(self.ecdhXs)

if isinstance(self.ecdhXs, ecdsa.keys.SigningKey):
ecdhYs = bytearray(
self.ecdhXs.get_verifying_key().to_string(
encoding = 'uncompressed'
)
)
else:
ecdhYs = kex.calc_public_value(self.ecdhXs)

version = self.serverHello.server_version
serverKeyExchange = ServerKeyExchange(self.cipherSuite, version)
Expand Down Expand Up @@ -742,7 +748,14 @@ def processServerKeyExchange(self, srvPublicKey, serverKeyExchange):
kex = ECDHKeyExchange(serverKeyExchange.named_curve,
self.serverHello.server_version)
ecdhXc = kex.get_random_private_key()
self.ecdhYc = kex.calc_public_value(ecdhXc)
if isinstance(ecdhXc, ecdsa.keys.SigningKey):
self.ecdhYc = bytearray(
ecdhXc.get_verifying_key().to_string(
encoding = 'uncompressed'
)
)
else:
self.ecdhYc = kex.calc_public_value(ecdhXc)
return kex.calc_shared_key(ecdhXc, ecdh_Ys)

def makeClientKeyExchange(self):
Expand Down Expand Up @@ -999,7 +1012,7 @@ def get_random_private_key(self):
return getRandomBytes(X448_ORDER_SIZE)
else:
curve = getCurveByName(GroupName.toStr(self.group))
return ecdsa.util.randrange(curve.generator.order())
return ecdsa.keys.SigningKey.generate(curve)

def _get_fun_gen_size(self):
"""Return the function and generator for X25519/X448 KEX."""
Expand All @@ -1010,30 +1023,42 @@ def _get_fun_gen_size(self):

def calc_public_value(self, private):
"""Calculate public value for given private key."""
if isinstance(private, ecdsa.keys.SigningKey):
return private.verifying_key.to_string('uncompressed')
if self.group in self._x_groups:
fun, generator, _ = self._get_fun_gen_size()
return fun(private, generator)
else:
curve = getCurveByName(GroupName.toStr(self.group))
return encodeX962Point(curve.generator * private)
point = curve.generator * private
return bytearray(point.to_bytes('uncompressed'))

def calc_shared_key(self, private, peer_share):
"""Calculate the shared key,"""

if self.group in self._x_groups:
fun, _, size = self._get_fun_gen_size()
if len(peer_share) != size:
raise TLSIllegalParameterException("Invalid key share")
if isinstance(private, ecdsa.keys.SigningKey):
private = bytesToNumber(private.to_string())
S = fun(private, peer_share)
self._non_zero_check(S)
return S
else:
curve = getCurveByName(GroupName.toRepr(self.group))
try:
ecdhYc = decodeX962Point(peer_share,
curve)
except (AssertionError, DecodeError):
raise TLSIllegalParameterException("Invalid ECC point")

S = ecdhYc * private

return numberToByteArray(S.x(), getPointByteSize(ecdhYc))
curve = getCurveByName(GroupName.toRepr(self.group))
try:
abstractPoint = ecdsa.ellipticcurve.AbstractPoint()
point = abstractPoint.from_bytes(curve.curve, peer_share)
ecdhYc = ecdsa.ellipticcurve.Point(
curve.curve, point[0], point[1])

except (AssertionError, DecodeError):
raise TLSIllegalParameterException("Invalid ECC point")
if isinstance(private, ecdsa.keys.SigningKey):
ecdh = ecdsa.ecdh.ECDH(curve=curve, private_key=private)
ecdh.load_received_public_key_bytes(peer_share)
return bytearray(ecdh.generate_sharedsecret_bytes())
S = ecdhYc * private

return numberToByteArray(S.x(), getPointByteSize(ecdhYc))
27 changes: 1 addition & 26 deletions tlslite/utils/ecc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,10 @@
# See the LICENSE file for legal information regarding use of this file.
"""Methods for dealing with ECC points"""

from .codec import Parser, Writer, DecodeError
from .cryptomath import bytesToNumber, numberToByteArray, numBytes
from .compat import ecdsaAllCurves
import ecdsa
from .compat import ecdsaAllCurves


def decodeX962Point(data, curve=ecdsa.NIST256p):
"""Decode a point from a X9.62 encoding"""
parser = Parser(data)
encFormat = parser.get(1)
if encFormat != 4:
raise DecodeError("Not an uncompressed point encoding")
bytelength = getPointByteSize(curve)
xCoord = bytesToNumber(parser.getFixBytes(bytelength))
yCoord = bytesToNumber(parser.getFixBytes(bytelength))
if parser.getRemainingLength():
raise DecodeError("Invalid length of point encoding for curve")
return ecdsa.ellipticcurve.Point(curve.curve, xCoord, yCoord)


def encodeX962Point(point):
"""Encode a point in X9.62 format"""
bytelength = numBytes(point.curve().p())
writer = Writer()
writer.add(4, 1)
writer.bytes += numberToByteArray(point.x(), bytelength)
writer.bytes += numberToByteArray(point.y(), bytelength)
return writer.bytes

def getCurveByName(curveName):
"""Return curve identified by curveName"""
curveMap = {'secp256r1':ecdsa.NIST256p,
Expand Down
23 changes: 13 additions & 10 deletions unit_tests/test_tlslite_keyexchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@
from tlslite.x509 import X509
from tlslite.x509certchain import X509CertChain
from tlslite.utils.keyfactory import parsePEMKey
from tlslite.utils.codec import Parser
from tlslite.utils.codec import Parser, Writer
from tlslite.utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \
numberToByteArray, isPrime, numBits
numberToByteArray, isPrime, numBytes
from tlslite.mathtls import makeX, makeU, makeK, goodGroupParameters
from tlslite.handshakehashes import HandshakeHashes
from tlslite import VerifierDB
from tlslite.extensions import SupportedGroupsExtension, SNIExtension
from tlslite.utils.ecc import getCurveByName, decodeX962Point, encodeX962Point,\
getPointByteSize
from tlslite.utils.ecc import getCurveByName, getPointByteSize
from tlslite.utils.compat import a2b_hex
import ecdsa
from operator import mul
Expand Down Expand Up @@ -1941,7 +1940,6 @@ def setUp(self):

def test_ECDHE_key_exchange(self):
srv_key_ex = self.keyExchange.makeServerKeyExchange('sha1')

KeyExchange.verifyServerKeyExchange(srv_key_ex,
self.srv_pub_key,
self.client_hello.random,
Expand All @@ -1953,10 +1951,15 @@ def test_ECDHE_key_exchange(self):
curve = getCurveByName(curveName)
generator = curve.generator
cln_Xc = ecdsa.util.randrange(generator.order())
cln_Ys = decodeX962Point(srv_key_ex.ecdh_Ys, curve)
cln_Yc = encodeX962Point(generator * cln_Xc)
abstractPoint = ecdsa.ellipticcurve.AbstractPoint().from_bytes(curve.curve, srv_key_ex.ecdh_Ys)
cln_Ys = ecdsa.ellipticcurve.Point(curve.curve,
abstractPoint[0],
abstractPoint[1])
point = generator * cln_Xc
cln_Yc = point.to_bytes('uncompressed')

cln_key_ex = ClientKeyExchange(self.cipher_suite, (3, 3))

cln_key_ex.createECDH(cln_Yc)

cln_S = cln_Ys * cln_Xc
Expand All @@ -1981,9 +1984,9 @@ def test_ECDHE_key_exchange_with_invalid_CKE(self):
curve = getCurveByName(curveName)
generator = curve.generator
cln_Xc = ecdsa.util.randrange(generator.order())
cln_Ys = decodeX962Point(srv_key_ex.ecdh_Ys, curve)
cln_Yc = encodeX962Point(generator * cln_Xc)

point = generator * cln_Xc
cln_Yc = bytearray(point.to_bytes('uncompressed'))
cln_key_ex = ClientKeyExchange(self.cipher_suite, (3, 3))
cln_key_ex.createECDH(cln_Yc)

Expand Down
142 changes: 1 addition & 141 deletions unit_tests/test_tlslite_utils_ecc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,149 +10,9 @@
except ImportError:
import unittest

from tlslite.utils.ecc import decodeX962Point, encodeX962Point, getCurveByName,\
getPointByteSize
from tlslite.utils.ecc import getCurveByName,getPointByteSize
import ecdsa

class TestEncoder(unittest.TestCase):
def test_encode_P_256_point(self):
point = ecdsa.NIST256p.generator * 200

self.assertEqual(encodeX962Point(point),
bytearray(b'\x04'
# x coordinate
b'\x3a\x53\x5b\xd0\xbe\x46\x6f\xf3\xd8\x56'
b'\xa0\x77\xaa\xd9\x50\x4f\x16\xaa\x5d\x52'
b'\x28\xfc\xd7\xc2\x77\x48\x85\xee\x21\x3f'
b'\x3b\x34'
# y coordinate
b'\x66\xab\xa8\x18\x5b\x33\x41\xe0\xc2\xe3'
b'\xd1\xb3\xae\x69\xe4\x7d\x0f\x01\xd4\xbb'
b'\xd7\x06\xd9\x57\x8b\x0b\x65\xd6\xd3\xde'
b'\x1e\xfe'
))

def test_encode_P_256_point_with_zero_first_byte_on_x(self):
point = ecdsa.NIST256p.generator * 379

self.assertEqual(encodeX962Point(point),
bytearray(b'\x04'
b'\x00\x55\x43\x89\x4a\xf3\xd0\x0e\xd7\xd7'
b'\x40\xab\xdb\xd7\x5c\x96\xb0\x68\x77\xb7'
b'\x87\xdb\x5f\x70\xee\xa7\x8b\x90\xa8\xd7'
b'\xc0\x0a'
b'\xbb\x4c\x85\xa3\xd8\xea\x29\xef\xaa\xfa'
b'\x24\x40\x69\x12\xdd\x84\xd5\xb1\x4d\xc3'
b'\x2b\xf6\x56\xef\x6c\x6b\xd5\x8a\x5d\x94'
b'\x3f\x92'
))

def test_encode_P_256_point_with_zero_first_byte_on_y(self):
point = ecdsa.NIST256p.generator * 43

self.assertEqual(encodeX962Point(point),
bytearray(b'\x04'
b'\x98\x6a\xe2\x50\x6f\x1f\xf1\x04\xd0\x42'
b'\x30\x86\x1d\x8f\x4b\x49\x8f\x4b\xc4\xc6'
b'\xd0\x09\xb3\x0f\x75\x44\xdc\x12\x9b\x82'
b'\xd2\x8d'
b'\x00\x3c\xcc\xc0\xa6\x46\x0e\x0a\xe3\x28'
b'\xa4\xd9\x7d\x3c\x7b\x61\xd8\x6f\xc6\x28'
b'\x9c\x18\x9f\x25\x25\x11\x0c\x44\x1b\xb0'
b'\x7e\x97'
))

def test_encode_P_256_point_with_two_zero_first_bytes_on_x(self):
point = ecdsa.NIST256p.generator * 40393

self.assertEqual(encodeX962Point(point),
bytearray(b'\x04'
b'\x00\x00\x3f\x5f\x17\x8a\xa0\x70\x6c\x42'
b'\x31\xeb\x6e\x54\x95\xaa\x16\x42\xc5\xb8'
b'\xa9\x94\x12\x7c\x89\x46\x5f\x22\x99\x4a'
b'\x42\xf9'
b'\xc2\x48\xb3\x37\x59\x9f\x0c\x2f\x29\x77'
b'\x2e\x25\x6f\x1d\x55\x49\xc8\x9b\xa9\xe5'
b'\x73\x13\x82\xcd\x1e\x3c\xc0\x9d\x10\xd0'
b'\x0b\x55'))

def test_encode_P_521_point(self):
point = ecdsa.NIST521p.generator * 200

self.assertEqual(encodeX962Point(point),
bytearray(b'\x04'
b'\x00\x3e\x2a\x2f\x9f\xd5\x9f\xc3\x8d\xfb'
b'\xde\x77\x26\xa0\xbf\xc6\x48\x2a\x6b\x2a'
b'\x86\xf6\x29\xb8\x34\xa0\x6c\x3d\x66\xcd'
b'\x79\x8d\x9f\x86\x2e\x89\x31\xf7\x10\xc7'
b'\xce\x89\x15\x9f\x35\x8b\x4a\x5c\x5b\xb3'
b'\xd2\xcc\x9e\x1b\x6e\x94\x36\x23\x6d\x7d'
b'\x6a\x5e\x00\xbc\x2b\xbe'
b'\x01\x56\x7a\x41\xcb\x48\x8d\xca\xd8\xe6'
b'\x3a\x3f\x95\xb0\x8a\xf6\x99\x2a\x69\x6a'
b'\x37\xdf\xc6\xa1\x93\xff\xbc\x3f\x91\xa2'
b'\x96\xf3\x3c\x66\x15\x57\x3c\x1c\x06\x7f'
b'\x0a\x06\x4d\x18\xbd\x0c\x81\x4e\xf7\x2a'
b'\x8f\x76\xf8\x7f\x9b\x7d\xff\xb2\xf4\x26'
b'\x36\x43\x43\x86\x11\x89'))

class TestDecoder(unittest.TestCase):
def test_decode_P_256_point(self):
point = ecdsa.NIST256p.generator * 379
data = bytearray(b'\x04'
b'\x00\x55\x43\x89\x4a\xf3\xd0\x0e\xd7\xd7'
b'\x40\xab\xdb\xd7\x5c\x96\xb0\x68\x77\xb7'
b'\x87\xdb\x5f\x70\xee\xa7\x8b\x90\xa8\xd7'
b'\xc0\x0a'
b'\xbb\x4c\x85\xa3\xd8\xea\x29\xef\xaa\xfa'
b'\x24\x40\x69\x12\xdd\x84\xd5\xb1\x4d\xc3'
b'\x2b\xf6\x56\xef\x6c\x6b\xd5\x8a\x5d\x94'
b'\x3f\x92'
)

decoded_point = decodeX962Point(data, ecdsa.NIST256p)

self.assertEqual(point, decoded_point)

def test_decode_P_521_point(self):

data = bytearray(b'\x04'
b'\x01\x7d\x8a\x5d\x11\x03\x4a\xaf\x01\x26'
b'\x5f\x2d\xd6\x2d\x76\xeb\xd8\xbe\x4e\xfb'
b'\x3b\x4b\xd2\x05\x5a\xed\x4c\x6d\x20\xc7'
b'\xf3\xd7\x08\xab\x21\x9e\x34\xfd\x14\x56'
b'\x3d\x47\xd0\x02\x65\x15\xc2\xdd\x2d\x60'
b'\x66\xf9\x15\x64\x55\x7a\xae\x56\xa6\x7a'
b'\x28\x51\x65\x26\x5c\xcc'
b'\x01\xd4\x19\x56\xfa\x14\x6a\xdb\x83\x1c'
b'\xb6\x1a\xc4\x4b\x40\xb1\xcb\xcc\x9e\x4f'
b'\x57\x2c\xb2\x72\x70\xb9\xef\x38\x15\xae'
b'\x87\x1f\x85\x40\x94\xda\x69\xed\x97\xeb'
b'\xdc\x72\x25\x25\x61\x76\xb2\xde\xed\xa2'
b'\xb0\x5c\xca\xc4\x83\x8f\xfb\x54\xae\xe0'
b'\x07\x45\x0b\xbf\x7c\xfc')

point = decodeX962Point(data, ecdsa.NIST521p)
self.assertIsNotNone(point)

self.assertEqual(encodeX962Point(point), data)

def test_decode_with_missing_data(self):
data = bytearray(b'\x04'
b'\x00\x55\x43\x89\x4a\xf3\xd0\x0e\xd7\xd7'
b'\x40\xab\xdb\xd7\x5c\x96\xb0\x68\x77\xb7'
b'\x87\xdb\x5f\x70\xee\xa7\x8b\x90\xa8\xd7'
b'\xc0\x0a'
b'\xbb\x4c\x85\xa3\xd8\xea\x29\xef\xaa\xfa'
b'\x24\x40\x69\x12\xdd\x84\xd5\xb1\x4d\xc3'
b'\x2b\xf6\x56\xef\x6c\x6b\xd5\x8a\x5d\x94'
#b'\x3f\x92'
)

# XXX will change later as decoder in tlslite-ng needs to be updated
with self.assertRaises(SyntaxError):
decodeX962Point(data, ecdsa.NIST256p)

class TestCurveLookup(unittest.TestCase):
def test_with_correct_name(self):
curve = getCurveByName('secp256r1')
Expand Down

0 comments on commit 6db0826

Please sign in to comment.