diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 448b70ce..d89bfaaf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -180,10 +180,26 @@ jobs: os: ubuntu-latest python-version: '3.11' opt-deps: ['brotli', 'zstd'] - - name: py3.12with brotli and zstandard + - name: py3.12 with brotli and zstandard os: ubuntu-latest python-version: '3.12' opt-deps: ['brotli', 'zstd'] + - name: py3.9 with kyber-py + os: ubuntu-latest + python-version: "3.9" + opt-deps: ["kyber_py"] + - name: py3.10 with kyber-py + os: ubuntu-latest + python-version: "3.10" + opt-deps: ["kyber_py"] + - name: py3.11 with kyber-py + os: ubuntu-latest + python-version: "3.11" + opt-deps: ["kyber_py"] + - name: py3.12 with kyber-py + os: ubuntu-latest + python-version: "3.12" + opt-deps: ["kyber_py"] # finally test with multiple dependencies installed at the same time - name: py2.7 with m2crypto, pycrypto, gmpy, gmpy2, and brotli os: ubuntu-20.04 @@ -204,22 +220,22 @@ jobs: - name: py3.9 with m2crypto, gmpy, gmpy2, brotli, and zstandard os: ubuntu-latest python-version: 3.9 - opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd'] + opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd', 'kyber_py'] - name: py3.10 with m2crypto, gmpy, gmpy2, brotli, and zstandard os: ubuntu-latest python-version: '3.10' - opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd'] + opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd', 'kyber_py'] - name: py3.11 with m2crypto, gmpy, gmpy2, brotli, and zstandard os: ubuntu-latest python-version: '3.11' # gmpy doesn't build with 3.11 - opt-deps: ['m2crypto', 'gmpy2', 'brotli', 'zstd'] + opt-deps: ['m2crypto', 'gmpy2', 'brotli', 'zstd', 'kyber_py'] - name: py3.12 with m2crypto, gmpy, gmpy2, brotli, and zstandard os: ubuntu-latest python-version: '3.12' # gmpy doesn't build with 3.12 # coverage to codeclimate can be submitted just once - opt-deps: ['m2crypto', 'gmpy2', 'codeclimate', 'brotli', 'zstd'] + opt-deps: ['m2crypto', 'gmpy2', 'codeclimate', 'brotli', 'zstd', 'kyber_py'] steps: - uses: actions/checkout@v2 if: ${{ !matrix.container }} @@ -346,6 +362,9 @@ jobs: - name: Install zstandard for py3.8 and after if: ${{ contains(matrix.opt-deps, 'zstd') }} run: pip install zstandard + - name: Install kyber_py + if: ${{ contains(matrix.opt-deps, 'kyber_py') }} + run: pip install "https://github.com/GiacomoPope/kyber-py/archive/b187189a514b3327578928c1d4c901d34592678e.zip" - name: Install build dependencies (2.6) if: ${{ matrix.python-version == '2.6' }} run: | diff --git a/scripts/tls.py b/scripts/tls.py index c18dc8ca..3159cbc3 100755 --- a/scripts/tls.py +++ b/scripts/tls.py @@ -34,7 +34,8 @@ GroupName, SignatureScheme from tlslite.handshakesettings import Keypair, VirtualHost from tlslite import __version__ -from tlslite.utils.compat import b2a_hex, a2b_hex, time_stamp +from tlslite.utils.compat import b2a_hex, a2b_hex, time_stamp, \ + ML_KEM_AVAILABLE from tlslite.utils.dns_utils import is_valid_hostname from tlslite.utils.cryptomath import getRandomBytes from tlslite.constants import KeyUpdateMessageType @@ -76,6 +77,10 @@ def printUsage(s=None): print(" GMPY2 : Loaded") else: print(" GMPY2 : Not Loaded") + if ML_KEM_AVAILABLE: + print(" Kyber-py : Loaded") + else: + print(" Kyber-py : Not Loaded") print("") print("Certificate compression algorithms:") diff --git a/tlslite/constants.py b/tlslite/constants.py index 63aa61f1..49617655 100644 --- a/tlslite/constants.py +++ b/tlslite/constants.py @@ -438,7 +438,13 @@ class GroupName(TLSEnum): brainpoolP512r1tls13 = 33 allEC.extend(list(range(31, 34))) - all = allEC + allFF + # draft-kwiatkowski-tls-ecdhe-mlkem + secp256r1mlkem768 = 0x11EB + x25519mlkem768 = 0x11EC + secp384r1mlkem1024 = 0x11ED + allKEM = [0x11EB, 0x11EC, 0x11ED] + + all = allEC + allFF + allKEM @classmethod def toRepr(cls, value, blacklist=None): diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py index 3a8755ac..d0d8af08 100644 --- a/tlslite/handshakesettings.py +++ b/tlslite/handshakesettings.py @@ -10,7 +10,7 @@ from .constants import CertificateType from .utils import cryptomath from .utils import cipherfactory -from .utils.compat import ecdsaAllCurves, int_types +from .utils.compat import ecdsaAllCurves, int_types, ML_KEM_AVAILABLE from .utils.compression import compression_algo_impls CIPHER_NAMES = ["chacha20-poly1305", @@ -34,10 +34,14 @@ ALL_RSA_SIGNATURE_HASHES = RSA_SIGNATURE_HASHES + ["md5"] SIGNATURE_SCHEMES = ["Ed25519", "Ed448"] RSA_SCHEMES = ["pss", "pkcs1"] +CURVE_NAMES = [] +if ML_KEM_AVAILABLE: + CURVE_NAMES += ["secp256r1mlkem768", "x25519mlkem768", + "secp384r1mlkem1024"] # while secp521r1 is the most secure, it's also much slower than the others # so place it as the last one -CURVE_NAMES = ["x25519", "x448", "secp384r1", "secp256r1", - "secp521r1"] +CURVE_NAMES += ["x25519", "x448", "secp384r1", "secp256r1", + "secp521r1"] ALL_CURVE_NAMES = CURVE_NAMES + ["secp256k1", "brainpoolP512r1", "brainpoolP384r1", "brainpoolP256r1"] if ecdsaAllCurves: @@ -57,7 +61,8 @@ TLS13_PERMITTED_GROUPS = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448", "ffdhe2048", "ffdhe3072", "ffdhe4096", "ffdhe6144", - "ffdhe8192"] + "ffdhe8192", "secp256r1mlkem768", "x25519mlkem768", + "secp384r1mlkem1024"] KNOWN_VERSIONS = ((3, 0), (3, 1), (3, 2), (3, 3), (3, 4)) TICKET_CIPHERS = ["chacha20-poly1305", "aes256gcm", "aes128gcm", "aes128ccm", "aes128ccm_8", "aes256ccm", "aes256ccm_8"] @@ -395,7 +400,11 @@ def _init_key_settings(self): self.dhParams = None self.dhGroups = list(ALL_DH_GROUP_NAMES) self.defaultCurve = "secp256r1" - self.keyShares = ["secp256r1", "x25519"] + if ML_KEM_AVAILABLE: + self.keyShares = ["x25519mlkem768"] + else: + self.keyShares = [] + self.keyShares += ["secp256r1", "x25519"] self.padding_cb = None self.use_heartbeat_extension = True self.heartbeat_response_callback = None diff --git a/tlslite/keyexchange.py b/tlslite/keyexchange.py index 2242aad3..6c49a975 100644 --- a/tlslite/keyexchange.py +++ b/tlslite/keyexchange.py @@ -21,9 +21,13 @@ from .utils import tlshashlib as hashlib from .utils.x25519 import x25519, x448, X25519_G, X448_G, X25519_ORDER_SIZE, \ X448_ORDER_SIZE -from .utils.compat import int_types +from .utils.compat import int_types, ML_KEM_AVAILABLE from .utils.codec import DecodeError +if ML_KEM_AVAILABLE: + from kyber_py.ml_kem import ML_KEM_768, ML_KEM_1024 + + class KeyExchange(object): """ Common API for calculating Premaster secret @@ -903,7 +907,7 @@ def get_random_private_key(self): """ raise NotImplementedError("Abstract class") - def calc_public_value(self, private): + def calc_public_value(self, private, point_format=None): """Calculate the public value from the provided private value.""" raise NotImplementedError("Abstract class") @@ -940,10 +944,11 @@ def get_random_private_key(self): needed_bytes = divceil(paramStrength(self.prime) * 2, 8) return bytesToNumber(getRandomBytes(needed_bytes)) - def calc_public_value(self, private): + def calc_public_value(self, private, point_format=None): """ Calculate the public value for given private value. + :param point_format: ignored, used for compatibility with ECDH groups :rtype: int """ dh_Y = powMod(self.generator, private, self.prime) @@ -1021,20 +1026,32 @@ def _get_fun_gen_size(self): else: return x448, bytearray(X448_G), X448_ORDER_SIZE - def calc_public_value(self, private): - """Calculate public value for given private key.""" + def calc_public_value(self, private, point_format='uncompressed'): + """ + Calculate public value for given private key. + + :param private: Private key for the selected key exchange group. + :param str point_format: The point format to use for the + ECDH public key. Applies only to NIST curves. + """ if isinstance(private, ecdsa.keys.SigningKey): - return private.verifying_key.to_string('uncompressed') + return private.verifying_key.to_string(point_format) if self.group in self._x_groups: fun, generator, _ = self._get_fun_gen_size() return fun(private, generator) else: curve = getCurveByName(GroupName.toStr(self.group)) point = curve.generator * private - return bytearray(point.to_bytes('uncompressed')) + return bytearray(point.to_bytes(point_format)) - def calc_shared_key(self, private, peer_share): - """Calculate the shared key,""" + def calc_shared_key(self, private, peer_share, + valid_point_formats=('uncompressed',)): + """ + Calculate the shared key. + + :param set(str) valid_point_formats: list of point formats that + the peer share can be in; ["uncompressed"] by default. + """ if self.group in self._x_groups: fun, _, size = self._get_fun_gen_size() @@ -1049,7 +1066,8 @@ def calc_shared_key(self, private, peer_share): curve = getCurveByName(GroupName.toRepr(self.group)) try: abstractPoint = ecdsa.ellipticcurve.AbstractPoint() - point = abstractPoint.from_bytes(curve.curve, peer_share) + point = abstractPoint.from_bytes( + curve.curve, peer_share, valid_encodings=valid_point_formats) ecdhYc = ecdsa.ellipticcurve.Point( curve.curve, point[0], point[1]) @@ -1062,3 +1080,176 @@ def calc_shared_key(self, private, peer_share): S = ecdhYc * private return numberToByteArray(S.x(), getPointByteSize(ecdhYc)) + + +class KEMKeyExchange(object): + """ + Implementation of the Hybrid KEM key exchange groups. + + Caution, KEMs are not symmetric! While they client calls the + same get_random_private_key(), calc_public_value(), and calc_shared_key() + as in FFDH or ECDH, the server calls just the encapsulate_key() method. + """ + + def __init__(self, group, version): + if not ML_KEM_AVAILABLE: + raise TLSInternalError("kyber-py library not installed!") + self.group = group + assert version == (3, 4) + del version + + if self.group not in GroupName.allKEM: + raise TLSInternalError("called with wrong group") + + if self.group == GroupName.secp256r1mlkem768: + self._classic_group = GroupName.secp256r1 + elif self.group == GroupName.x25519mlkem768: + self._classic_group = GroupName.x25519 + else: + assert self.group == GroupName.secp384r1mlkem1024 + self._classic_group = GroupName.secp384r1 + + def get_random_private_key(self): + """ + Generates a random value to be used as the private key in KEM. + + To be used only to generate the KeyShare in ClientHello. + """ + + if self.group not in GroupName.allKEM: + raise TLSInternalError("called with wrong group") + if self.group in (GroupName.secp256r1mlkem768, + GroupName.x25519mlkem768): + pqc_pub_key, pqc_priv_key = ML_KEM_768.keygen() + else: + pqc_pub_key, pqc_priv_key = ML_KEM_1024.keygen() + + classic_kex = ECDHKeyExchange(self._classic_group, (3, 4)) + classic_key = classic_kex.get_random_private_key() + + return ((pqc_pub_key, pqc_priv_key), classic_key) + + def calc_public_value(self, private, point_format='uncompressed'): + """ + Extract public values for the private key. + + To be used only to generate the KeyShare in ClientHello. + + :param str point_format: Point format of the ECDH portion of the + key exchange (effective only for NIST curves, valid is + 'uncompressed' only) + """ + classic_kex = ECDHKeyExchange(self._classic_group, (3, 4)) + + classic_pub_key_share = classic_kex.calc_public_value( + private[1], point_format=point_format) + + if self.group == GroupName.x25519mlkem768: + return private[0][0] + classic_pub_key_share + return classic_pub_key_share + private[0][0] + + @staticmethod + def _split_key_shares(public, pqc_first, pqc_key_len, classic_key_len): + if len(public) != classic_key_len + pqc_key_len: + raise TLSIllegalParameterException( + "Invalid key size for the selected group. " + "Expected: {0}, received: {1}".format( + classic_key_len + pqc_key_len, + len(public))) + + if pqc_first: + pqc_key = public[:pqc_key_len] + classic_key_share = bytearray(public[pqc_key_len:]) + else: + classic_key_share = bytearray(public[:classic_key_len]) + pqc_key = public[classic_key_len:] + + return pqc_key, classic_key_share + + def _group_to_params(self): + """Returns a tuple: + classic_key_len, pqc_ek_key_len, pqc_ciphertext_len, pqc_first, ML_KEM + """ + if self.group == GroupName.secp256r1mlkem768: + classic_key_len = 65 + pqc_key_len = 1184 + pqc_ciphertext_len = 1088 + pqc_first = False + ml_kem = ML_KEM_768 + elif self.group == GroupName.x25519mlkem768: + classic_key_len = 32 + pqc_key_len = 1184 + pqc_ciphertext_len = 1088 + pqc_first = True + ml_kem = ML_KEM_768 + else: + assert self.group == GroupName.secp384r1mlkem1024 + classic_key_len = 97 + pqc_key_len = 1568 + pqc_ciphertext_len = 1568 + pqc_first = False + ml_kem = ML_KEM_1024 + + return classic_key_len, pqc_key_len, pqc_ciphertext_len, pqc_first, \ + ml_kem + + def encapsulate_key(self, public): + """ + Generate a random secret, encapsulate it given the public key, + and return both the random secret and encapsulation of it. + + To be used for generation of KeyShare in ServerHello. + """ + classic_key_len, pqc_key_len, _, pqc_first, ml_kem = \ + self._group_to_params() + + pqc_key, classic_key_share = self._split_key_shares( + public, pqc_first, pqc_key_len, classic_key_len) + + classic_kex = ECDHKeyExchange(self._classic_group, (3, 4)) + classic_key = classic_kex.get_random_private_key() + classic_my_key_share = classic_kex.calc_public_value(classic_key) + classic_shared_secret = classic_kex.calc_shared_key( + classic_key, classic_key_share) + + try: + pqc_shared_secret, pqc_encaps = ml_kem.encaps(pqc_key) + except ValueError: + raise TLSIllegalParameterException( + "Invalid PQC key from peer") + + if pqc_first: + shared_secret = pqc_shared_secret + classic_shared_secret + key_encapsulation = pqc_encaps + classic_my_key_share + else: + shared_secret = classic_shared_secret + pqc_shared_secret + key_encapsulation = classic_my_key_share + pqc_encaps + + return shared_secret, key_encapsulation + + def calc_shared_key(self, private, key_encaps): + """ + Decapsulate the key share received from server. + """ + classic_key_len, _, pqc_key_len, pqc_first, ml_kem = \ + self._group_to_params() + + pqc_key, classic_key_share = self._split_key_shares( + key_encaps, pqc_first, pqc_key_len, classic_key_len) + + classic_kex = ECDHKeyExchange(self._classic_group, (3, 4)) + classic_shared_secret = classic_kex.calc_shared_key( + private[1], classic_key_share) + + try: + pqc_shared_secret = ml_kem.decaps(private[0][1], pqc_key) + except ValueError: + raise TLSIllegalParameterException( + "Error in KEM decapsulation") + + if pqc_first: + shared_secret = pqc_shared_secret + classic_shared_secret + else: + shared_secret = classic_shared_secret + pqc_shared_secret + + return shared_secret diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py index 7abfe2e3..abb7ce83 100644 --- a/tlslite/tlsconnection.py +++ b/tlslite/tlsconnection.py @@ -35,7 +35,7 @@ from .utils.deprecations import deprecated_params from .keyexchange import KeyExchange, RSAKeyExchange, DHE_RSAKeyExchange, \ ECDHE_RSAKeyExchange, SRPKeyExchange, ADHKeyExchange, \ - AECDHKeyExchange, FFDHKeyExchange, ECDHKeyExchange + AECDHKeyExchange, FFDHKeyExchange, ECDHKeyExchange, KEMKeyExchange from .handshakehelpers import HandshakeHelpers from .utils.cipherfactory import createAESCCM, createAESCCM_8, \ createAESGCM, createCHACHA20 @@ -1196,6 +1196,8 @@ def _clientGetServerHello(self, settings, session, clientHello): @staticmethod def _getKEX(group, version): """Get object for performing key exchange.""" + if group in GroupName.allKEM: + return KEMKeyExchange(group, version) if group in GroupName.allFF: return FFDHKeyExchange(group, version) return ECDHKeyExchange(group, version) @@ -1209,6 +1211,15 @@ def _genKeyShareEntry(cls, group, version): share = kex.calc_public_value(private) return KeyShareEntry().create(group, share, private) + @classmethod + def _KEMEncaps(cls, group, public): + """Generate the server's KeyShareEntry object with encapsulated secret. + """ + kex = cls._getKEX(group, (3, 4)) + shared_sec, key_share_value = kex.encapsulate_key(public) + key_share = KeyShareEntry().create(group, key_share_value, None) + return shared_sec, key_share + @staticmethod def _getPRFParams(cipher_suite): """Return name of hash used for PRF and the hash output size.""" @@ -2430,7 +2441,7 @@ def _handshakeServerAsyncHelper(self, verifierDB, dhGroups) elif cipherSuite in CipherSuite.ecdheCertSuites or \ cipherSuite in CipherSuite.ecdheEcdsaSuites: - acceptedCurves = self._curveNamesToList(settings) + acceptedCurves = self._curveNamesToList(settings, version) defaultCurve = getattr(GroupName, settings.defaultCurve) keyExchange = ECDHE_RSAKeyExchange(cipherSuite, clientHello, @@ -2457,7 +2468,7 @@ def _handshakeServerAsyncHelper(self, verifierDB, serverHello, settings.dhParams, dhGroups) else: - acceptedCurves = self._curveNamesToList(settings) + acceptedCurves = self._curveNamesToList(settings, version) defaultCurve = getattr(GroupName, settings.defaultCurve) keyExchange = AECDHKeyExchange(cipherSuite, clientHello, serverHello, acceptedCurves, @@ -2803,16 +2814,27 @@ def _serverTLS13Handshake(self, settings, clientHello, cipherSuite, (psk is None and privateKey): self.ecdhCurve = selected_group kex = self._getKEX(selected_group, version) - key_share = self._genKeyShareEntry(selected_group, version) + if selected_group in GroupName.allKEM: + try: + shared_sec, key_share = self._KEMEncaps( + selected_group, + cl_key_share.key_exchange) + except TLSIllegalParameterException as alert: + for result in self._sendError( + AlertDescription.illegal_parameter, + str(alert)): + yield result + else: + key_share = self._genKeyShareEntry(selected_group, version) - try: - shared_sec = kex.calc_shared_key(key_share.private, - cl_key_share.key_exchange) - except TLSIllegalParameterException as alert: - for result in self._sendError( - AlertDescription.illegal_parameter, - str(alert)): - yield result + try: + shared_sec = kex.calc_shared_key(key_share.private, + cl_key_share.key_exchange) + except TLSIllegalParameterException as alert: + for result in self._sendError( + AlertDescription.illegal_parameter, + str(alert)): + yield result sh_extensions.append(ServerKeyShareExtension().create(key_share)) elif (psk is not None and @@ -3557,7 +3579,7 @@ def _serverGetClientHello(self, settings, private_key, cert_chain, AlertDescription.decode_error, "Received malformed supported_groups extension"): yield result - serverGroups = self._curveNamesToList(settings) + serverGroups = self._curveNamesToList(settings, version) ecGroupIntersect = getFirstMatching(clientGroups, serverGroups) # RFC 7919 groups serverGroups = self._groupNamesToList(settings) @@ -4913,9 +4935,14 @@ def _sigHashesToList(settings, privateKey=None, certList=None, return sigAlgs @staticmethod - def _curveNamesToList(settings): + def _curveNamesToList(settings, version=(3, 4)): """Convert list of acceptable curves to array identifiers""" - return [getattr(GroupName, val) for val in settings.eccCurves] + ret = [getattr(GroupName, val) for val in settings.eccCurves] + if (settings.maxVersion < (3, 4) and (3, 4) not in settings.versions)\ + or version < (3, 4): + # if we don't support TLS 1.3, filter out KEMs + ret = [i for i in ret if i not in GroupName.allKEM] + return ret @staticmethod def _groupNamesToList(settings): diff --git a/tlslite/utils/compat.py b/tlslite/utils/compat.py index 359de7f5..71945d67 100644 --- a/tlslite/utils/compat.py +++ b/tlslite/utils/compat.py @@ -235,3 +235,14 @@ def byte_length(val): ecdsaAllCurves = False else: ecdsaAllCurves = True + + +# kyber-py is an optional dependency +try: + from kyber_py.ml_kem import ML_KEM_768, ML_KEM_1024 + del ML_KEM_768 + del ML_KEM_1024 +except ImportError: + ML_KEM_AVAILABLE = False +else: + ML_KEM_AVAILABLE = True diff --git a/unit_tests/test_tlslite_keyexchange.py b/unit_tests/test_tlslite_keyexchange.py index c6e91076..ee142620 100644 --- a/unit_tests/test_tlslite_keyexchange.py +++ b/unit_tests/test_tlslite_keyexchange.py @@ -35,7 +35,7 @@ from tlslite import VerifierDB from tlslite.extensions import SupportedGroupsExtension, SNIExtension from tlslite.utils.ecc import getCurveByName, getPointByteSize -from tlslite.utils.compat import a2b_hex +from tlslite.utils.compat import a2b_hex, ML_KEM_AVAILABLE import ecdsa from operator import mul try: @@ -45,7 +45,7 @@ from tlslite.keyexchange import KeyExchange, RSAKeyExchange, \ DHE_RSAKeyExchange, SRPKeyExchange, ECDHE_RSAKeyExchange, \ - RawDHKeyExchange, FFDHKeyExchange + RawDHKeyExchange, FFDHKeyExchange, KEMKeyExchange from tlslite.utils.x25519 import x25519, X25519_G, x448, X448_G from tlslite.mathtls import RFC7919_GROUPS from tlslite.utils.python_key import Python_Key @@ -2583,3 +2583,156 @@ def test_calc_shared_secret_for_invalid_sized_input(self): key_share = bytearray(b'\x00' * 10 + b'\x04') with self.assertRaises(TLSIllegalParameterException): kex.calc_shared_key(private, key_share) + + +@unittest.skipIf(not ML_KEM_AVAILABLE, "Kyber-py not installed") +class TestKEMKeyExchange(unittest.TestCase): + def test_init_with_wrong_group(self): + with self.assertRaises(TLSInternalError): + KEMKeyExchange(GroupName.x25519, (3, 4)) + + def test_with_wrong_key_share_size(self): + group = GroupName.x25519mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + with self.assertRaises(TLSIllegalParameterException) as e: + # one byte too long + kex.encapsulate_key(bytearray(32 + 1184 + 1)) + + self.assertIn("Invalid key size", str(e.exception)) + + def test_with_invalid_classic_key_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + alice_key_share = kex.calc_public_value(alice_private_key) + alice_key_share = bytearray(alice_key_share) + + alice_key_share[1] ^= 0xff + + with self.assertRaises(TLSIllegalParameterException) as e: + kex.encapsulate_key(alice_key_share) + + self.assertIn("Invalid ECC", str(e.exception)) + + def test_with_invalid_pqc_key_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + alice_key_share = kex.calc_public_value(alice_private_key) + alice_key_share = bytearray(alice_key_share) + + alice_key_share[67] = 0xff + + with self.assertRaises(TLSIllegalParameterException) as e: + kex.encapsulate_key(alice_key_share) + + self.assertIn("Invalid PQC", str(e.exception)) + + def test_with_modified_pqc_key_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + alice_key_share = kex.calc_public_value(alice_private_key) + alice_key_share = bytearray(alice_key_share) + + alice_key_share[67] = 0x01 + alice_key_share[68] = 0x01 + + bob_shared_secret, bob_key_share = kex.encapsulate_key(alice_key_share) + + alice_shared_secret = kex.calc_shared_key( + alice_private_key, bob_key_share) + + self.assertNotEqual(alice_shared_secret, bob_shared_secret) + + def test_decaps_with_wrong_size_of_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + + with self.assertRaises(TLSIllegalParameterException) as e: + kex.calc_shared_key(alice_private_key, bytearray(65 + 1088 + 1)) + + self.assertIn("Invalid key size", str(e.exception)) + + def test_decaps_with_invalid_classical_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + alice_key_share = kex.calc_public_value(alice_private_key) + + bob_shared_secret, bob_key_share = kex.encapsulate_key(alice_key_share) + bob_key_share = bytearray(bob_key_share) + + bob_key_share[2] ^= 0xff + + with self.assertRaises(TLSIllegalParameterException) as e: + kex.calc_shared_key(alice_private_key, bob_key_share) + + self.assertIn("Invalid ECC", str(e.exception)) + + def test_decaps_with_invalid_pqc_share(self): + group = GroupName.secp256r1mlkem768 + version = (3, 4) + + kex = KEMKeyExchange(group, version) + + alice_private_key = kex.get_random_private_key() + alice_key_share = kex.calc_public_value(alice_private_key) + + bob_shared_secret, bob_key_share = kex.encapsulate_key(alice_key_share) + bob_key_share = bytearray(bob_key_share) + + bob_key_share[68] ^= 0xff + + alice_shared_secret = kex.calc_shared_key( + alice_private_key, bob_key_share) + + self.assertNotEqual(alice_shared_secret, bob_shared_secret) + + def do_kex(self, group): + version = (3, 4) + + alice_kex = KEMKeyExchange(group, version) + + alice_private_key = alice_kex.get_random_private_key() + alice_key_share = alice_kex.calc_public_value(alice_private_key) + + bob_kex = KEMKeyExchange(group, version) + bob_shared_secret, bob_key_share = \ + bob_kex.encapsulate_key(alice_key_share) + + alice_shared_secret = alice_kex.calc_shared_key( + alice_private_key, bob_key_share) + + self.assertEqual(alice_shared_secret, bob_shared_secret) + + def test_x25519_ml_kem_768(self): + group = GroupName.x25519mlkem768 + self.do_kex(group) + + def test_p256_ml_kem_768(self): + group = GroupName.secp256r1mlkem768 + self.do_kex(group) + + def test_p384_ml_kem_1024(self): + group = GroupName.secp384r1mlkem1024 + self.do_kex(group)