From a6301c5af08d39121c1e1e7dc9ad1b9e9fe45942 Mon Sep 17 00:00:00 2001 From: Tim Ruffing Date: Wed, 4 Mar 2020 21:21:36 +0100 Subject: [PATCH 1/5] Optionally print intermediate values in reference code and make reference code and pseudocode more consistent with each other --- bip-0340.mediawiki | 6 ++-- bip-0340/reference.py | 68 +++++++++++++++++++++++++++++++++---------- 2 files changed, 56 insertions(+), 18 deletions(-) diff --git a/bip-0340.mediawiki b/bip-0340.mediawiki index 883ef3aab5..b4e5f602e3 100644 --- a/bip-0340.mediawiki +++ b/bip-0340.mediawiki @@ -136,9 +136,9 @@ Input: * The secret key ''sk'': a 32-byte array, freshly generated uniformly at random The algorithm ''PubKey(sk)'' is defined as: -* Let ''d = int(sk)''. -* Fail if ''d = 0'' or ''d ≥ n''. -* Return ''bytes(d⋅G)''. +* Let ''d' = int(sk)''. +* Fail if ''d' = 0'' or ''d' ≥ n''. +* Return ''bytes(d'⋅G)''. Note that we use a very different public key format (32 bytes) than the ones used by existing systems (which typically use elliptic curve points as public keys, or 33-byte or 65-byte encodings of them). A side effect is that ''PubKey(sk) = PubKey(bytes(n - int(sk))'', so every public key has two corresponding secret keys. diff --git a/bip-0340/reference.py b/bip-0340/reference.py index 79f957816d..d6106fdf32 100644 --- a/bip-0340/reference.py +++ b/bip-0340/reference.py @@ -1,6 +1,15 @@ import hashlib import binascii +# Set DEBUG to True to get a detailed debug output including +# intermediate values during key generation, signing, and +# verification. This is implemented via calls to the +# debug_print_vars() function. +# +# If you want to print values on an individual basis, use +# the pretty() function, e.g., print(pretty(foo)). +DEBUG = False + p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 @@ -62,7 +71,7 @@ def lift_x_square_y(b): y = pow(y_sq, (p + 1) // 4, p) if pow(y, 2, p) != y_sq: return None - return [x, y] + return (x, y) def lift_x_even_y(b): P = lift_x_square_y(b) @@ -87,32 +96,37 @@ def has_even_y(P): return y(P) % 2 == 0 def pubkey_gen(seckey): - x = int_from_bytes(seckey) - if not (1 <= x <= n - 1): + d0 = int_from_bytes(seckey) + if not (1 <= d0 <= n - 1): + debug_print_vars() raise ValueError('The secret key must be an integer in the range 1..n-1.') - P = point_mul(G, x) + P = point_mul(G, d0) return bytes_from_point(P) -def schnorr_sign(msg, seckey0, aux_rand): +def schnorr_sign(msg, seckey, aux_rand): if len(msg) != 32: + debug_print_vars() raise ValueError('The message must be a 32-byte array.') - seckey0 = int_from_bytes(seckey0) - if not (1 <= seckey0 <= n - 1): + d0 = int_from_bytes(seckey) + if not (1 <= d0 <= n - 1): raise ValueError('The secret key must be an integer in the range 1..n-1.') if len(aux_rand) != 32: raise ValueError('aux_rand must be 32 bytes instead of %i.' % len(aux_rand)) - P = point_mul(G, seckey0) - seckey = seckey0 if has_even_y(P) else n - seckey0 - t = xor_bytes(bytes_from_int(seckey), tagged_hash("BIP340/aux", aux_rand)) + P = point_mul(G, d0) + d = d0 if has_even_y(P) else n - d0 + t = xor_bytes(bytes_from_int(d), tagged_hash("BIP340/aux", aux_rand)) k0 = int_from_bytes(tagged_hash("BIP340/nonce", t + bytes_from_point(P) + msg)) % n if k0 == 0: + debug_print_vars() raise RuntimeError('Failure. This happens only with negligible probability.') R = point_mul(G, k0) k = n - k0 if not has_square_y(R) else k0 e = int_from_bytes(tagged_hash("BIP340/challenge", bytes_from_point(R) + bytes_from_point(P) + msg)) % n - sig = bytes_from_point(R) + bytes_from_int((k + e * seckey) % n) + sig = bytes_from_point(R) + bytes_from_int((k + e * d) % n) if not schnorr_verify(msg, bytes_from_point(P), sig): + debug_print_vars() raise RuntimeError('The signature does not pass verification.') + debug_print_vars() return sig def schnorr_verify(msg, pubkey, sig): @@ -123,26 +137,29 @@ def schnorr_verify(msg, pubkey, sig): if len(sig) != 64: raise ValueError('The signature must be a 64-byte array.') P = lift_x_even_y(pubkey) - if (P is None): - return False r = int_from_bytes(sig[0:32]) s = int_from_bytes(sig[32:64]) - if (r >= p or s >= n): + if (P is None) or (r >= p) or (s >= n): + debug_print_vars() return False e = int_from_bytes(tagged_hash("BIP340/challenge", sig[0:32] + pubkey + msg)) % n R = point_add(point_mul(G, s), point_mul(P, n - e)) if R is None or not has_square_y(R) or x(R) != r: + debug_print_vars() return False + debug_print_vars() return True # # The following code is only used to verify the test vectors. # import csv +import os +import sys def test_vectors(): all_passed = True - with open('test-vectors.csv', newline='') as csvfile: + with open(os.path.join(sys.path[0], 'test-vectors.csv'), newline='') as csvfile: reader = csv.reader(csvfile) reader.__next__() for row in reader: @@ -185,5 +202,26 @@ def test_vectors(): print('Some test vectors failed.') return all_passed +# +# The following code is only used for debugging +# +import inspect + +def pretty(v): + if isinstance(v, bytes): + return '0x' + v.hex() + if isinstance(v, int): + return pretty(bytes_from_int(v)) + if isinstance(v, tuple): + return tuple(map(pretty, v)) + return v + +def debug_print_vars(): + if DEBUG: + frame = inspect.currentframe().f_back + print(' Variables in function ', frame.f_code.co_name, ' at line ', frame.f_lineno, ':', sep='') + for var_name, var_val in frame.f_locals.items(): + print(' ' + var_name.rjust(11, ' '), '==', pretty(var_val)) + if __name__ == '__main__': test_vectors() From 8c5be9197540e1673187ff099b5fe40ce09c9216 Mon Sep 17 00:00:00 2001 From: Tim Ruffing Date: Thu, 12 Mar 2020 21:13:09 +0100 Subject: [PATCH 2/5] Make code and output a little bit more readable --- bip-0340/reference.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/bip-0340/reference.py b/bip-0340/reference.py index d6106fdf32..346b639833 100644 --- a/bip-0340/reference.py +++ b/bip-0340/reference.py @@ -33,13 +33,13 @@ def y(P): return P[1] def point_add(P1, P2): - if (P1 is None): + if P1 is None: return P2 - if (P2 is None): + if P2 is None: return P1 - if (x(P1) == x(P2) and y(P1) != y(P2)): + if (x(P1) == x(P2)) and (y(P1) != y(P2)): return None - if (P1 == P2): + if P1 == P2: lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p else: lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p @@ -49,7 +49,7 @@ def point_add(P1, P2): def point_mul(P, n): R = None for i in range(256): - if ((n >> i) & 1): + if (n >> i) & 1: R = point_add(R, P) P = point_add(P, P) return R @@ -90,7 +90,7 @@ def is_square(x): return pow(x, (p - 1) // 2, p) == 1 def has_square_y(P): - return not is_infinity(P) and is_square(y(P)) + return (not is_infinity(P)) and (is_square(y(P))) def has_even_y(P): return y(P) % 2 == 0 @@ -144,7 +144,7 @@ def schnorr_verify(msg, pubkey, sig): return False e = int_from_bytes(tagged_hash("BIP340/challenge", sig[0:32] + pubkey + msg)) % n R = point_add(point_mul(G, s), point_mul(P, n - e)) - if R is None or not has_square_y(R) or x(R) != r: + if (R is None) or (not has_square_y(R)) or (x(R) != r): debug_print_vars() return False debug_print_vars() @@ -168,7 +168,7 @@ def test_vectors(): msg = bytes.fromhex(msg) sig = bytes.fromhex(sig) result = result == 'TRUE' - print('\nTest vector #%-3i: ' % int(index)) + print('\nTest vector', ('#' + index).rjust(3, ' ') + ':') if seckey != '': seckey = bytes.fromhex(seckey) pubkey_actual = pubkey_gen(seckey) From 003d38cedbe1d8f550ea5032c16373c9779d28e3 Mon Sep 17 00:00:00 2001 From: Tim Ruffing Date: Thu, 12 Mar 2020 18:23:07 +0100 Subject: [PATCH 3/5] Fix typo --- bip-0340/test-vectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bip-0340/test-vectors.py b/bip-0340/test-vectors.py index d1a52c823d..9c029ec925 100644 --- a/bip-0340/test-vectors.py +++ b/bip-0340/test-vectors.py @@ -15,7 +15,7 @@ def vector0(): P = point_mul(G, x) assert(y(P) % 2 == 0) - # For historic reasons (pubkey tiebreaker was squareness and not evenness) + # For historical reasons (pubkey tiebreaker was squareness and not evenness) # we should have at least one test vector where the the point reconstructed # from the public key has a square and one where it has a non-square Y # coordinate. In this one Y is non-square. From 07d938a214475929e08df17e725b3904a3429dbf Mon Sep 17 00:00:00 2001 From: Tim Ruffing Date: Tue, 17 Mar 2020 02:13:26 +0100 Subject: [PATCH 4/5] fixup! Optionally print intermediate values in reference code --- bip-0340/reference.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/bip-0340/reference.py b/bip-0340/reference.py index 346b639833..da1e689c57 100644 --- a/bip-0340/reference.py +++ b/bip-0340/reference.py @@ -78,7 +78,7 @@ def lift_x_even_y(b): if P is None: return None else: - return [x(P), y(P) if y(P) % 2 == 0 else p - y(P)] + return (x(P), y(P) if y(P) % 2 == 0 else p - y(P)) def int_from_bytes(b): return int.from_bytes(b, byteorder="big") @@ -90,7 +90,7 @@ def is_square(x): return pow(x, (p - 1) // 2, p) == 1 def has_square_y(P): - return (not is_infinity(P)) and (is_square(y(P))) + return (not is_infinity(P)) and is_square(y(P)) def has_even_y(P): return y(P) % 2 == 0 @@ -98,14 +98,12 @@ def has_even_y(P): def pubkey_gen(seckey): d0 = int_from_bytes(seckey) if not (1 <= d0 <= n - 1): - debug_print_vars() raise ValueError('The secret key must be an integer in the range 1..n-1.') P = point_mul(G, d0) return bytes_from_point(P) def schnorr_sign(msg, seckey, aux_rand): if len(msg) != 32: - debug_print_vars() raise ValueError('The message must be a 32-byte array.') d0 = int_from_bytes(seckey) if not (1 <= d0 <= n - 1): @@ -117,16 +115,14 @@ def schnorr_sign(msg, seckey, aux_rand): t = xor_bytes(bytes_from_int(d), tagged_hash("BIP340/aux", aux_rand)) k0 = int_from_bytes(tagged_hash("BIP340/nonce", t + bytes_from_point(P) + msg)) % n if k0 == 0: - debug_print_vars() raise RuntimeError('Failure. This happens only with negligible probability.') R = point_mul(G, k0) k = n - k0 if not has_square_y(R) else k0 e = int_from_bytes(tagged_hash("BIP340/challenge", bytes_from_point(R) + bytes_from_point(P) + msg)) % n sig = bytes_from_point(R) + bytes_from_int((k + e * d) % n) + debug_print_vars() if not schnorr_verify(msg, bytes_from_point(P), sig): - debug_print_vars() raise RuntimeError('The signature does not pass verification.') - debug_print_vars() return sig def schnorr_verify(msg, pubkey, sig): From 72657270d8e4d6ef193878f1e743301edfae0e31 Mon Sep 17 00:00:00 2001 From: Tim Ruffing Date: Tue, 17 Mar 2020 02:30:39 +0100 Subject: [PATCH 5/5] When checking test vectors, handle RuntimeException in signing This is better for playing around with the code. Now these these exceptions can really be raised when the verification during signing fails. --- bip-0340/reference.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/bip-0340/reference.py b/bip-0340/reference.py index da1e689c57..6b1645ccd4 100644 --- a/bip-0340/reference.py +++ b/bip-0340/reference.py @@ -122,7 +122,7 @@ def schnorr_sign(msg, seckey, aux_rand): sig = bytes_from_point(R) + bytes_from_int((k + e * d) % n) debug_print_vars() if not schnorr_verify(msg, bytes_from_point(P), sig): - raise RuntimeError('The signature does not pass verification.') + raise RuntimeError('The created signature does not pass verification.') return sig def schnorr_verify(msg, pubkey, sig): @@ -173,13 +173,17 @@ def test_vectors(): print(' Expected key:', pubkey.hex().upper()) print(' Actual key:', pubkey_actual.hex().upper()) aux_rand = bytes.fromhex(aux_rand) - sig_actual = schnorr_sign(msg, seckey, aux_rand) - if sig == sig_actual: - print(' * Passed signing test.') - else: - print(' * Failed signing test.') - print(' Expected signature:', sig.hex().upper()) - print(' Actual signature:', sig_actual.hex().upper()) + try: + sig_actual = schnorr_sign(msg, seckey, aux_rand) + if sig == sig_actual: + print(' * Passed signing test.') + else: + print(' * Failed signing test.') + print(' Expected signature:', sig.hex().upper()) + print(' Actual signature:', sig_actual.hex().upper()) + all_passed = False + except RuntimeError as e: + print(' * Signing test raised exception:', e) all_passed = False result_actual = schnorr_verify(msg, pubkey, sig) if result == result_actual: