Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add debug print for intermediate values #200

Merged
merged 5 commits into from
Apr 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions bip-0340.mediawiki
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ Input:
* The secret key ''sk'': a 32-byte array, freshly generated uniformly at random

The algorithm ''PubKey(sk)'' is defined as:
* Let ''d = int(sk)''.
* Fail if ''d = 0'' or ''d ≥ n''.
* Return ''bytes(d⋅G)''.
* Let ''d' = int(sk)''.
* Fail if ''d' = 0'' or ''d' ≥ n''.
* Return ''bytes(d'⋅G)''.

Note that we use a very different public key format (32 bytes) than the ones used by existing systems (which typically use elliptic curve points as public keys, or 33-byte or 65-byte encodings of them). A side effect is that ''PubKey(sk) = PubKey(bytes(n - int(sk))'', so every public key has two corresponding secret keys.

Expand Down
102 changes: 70 additions & 32 deletions bip-0340/reference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import hashlib
import binascii

# Set DEBUG to True to get a detailed debug output including
# intermediate values during key generation, signing, and
# verification. This is implemented via calls to the
# debug_print_vars() function.
#
# If you want to print values on an individual basis, use
# the pretty() function, e.g., print(pretty(foo)).
DEBUG = False

p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

Expand All @@ -24,13 +33,13 @@ def y(P):
return P[1]

def point_add(P1, P2):
if (P1 is None):
if P1 is None:
return P2
if (P2 is None):
if P2 is None:
return P1
if (x(P1) == x(P2) and y(P1) != y(P2)):
if (x(P1) == x(P2)) and (y(P1) != y(P2)):
return None
if (P1 == P2):
if P1 == P2:
lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p
else:
lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p
Expand All @@ -40,7 +49,7 @@ def point_add(P1, P2):
def point_mul(P, n):
R = None
for i in range(256):
if ((n >> i) & 1):
if (n >> i) & 1:
R = point_add(R, P)
P = point_add(P, P)
return R
Expand All @@ -62,14 +71,14 @@ def lift_x_square_y(b):
y = pow(y_sq, (p + 1) // 4, p)
if pow(y, 2, p) != y_sq:
return None
return [x, y]
return (x, y)
real-or-random marked this conversation as resolved.
Show resolved Hide resolved

def lift_x_even_y(b):
P = lift_x_square_y(b)
if P is None:
return None
else:
return [x(P), y(P) if y(P) % 2 == 0 else p - y(P)]
return (x(P), y(P) if y(P) % 2 == 0 else p - y(P))

def int_from_bytes(b):
return int.from_bytes(b, byteorder="big")
Expand All @@ -81,38 +90,39 @@ def is_square(x):
return pow(x, (p - 1) // 2, p) == 1

def has_square_y(P):
return not is_infinity(P) and is_square(y(P))
return (not is_infinity(P)) and is_square(y(P))

def has_even_y(P):
return y(P) % 2 == 0

def pubkey_gen(seckey):
x = int_from_bytes(seckey)
if not (1 <= x <= n - 1):
d0 = int_from_bytes(seckey)
if not (1 <= d0 <= n - 1):
raise ValueError('The secret key must be an integer in the range 1..n-1.')
P = point_mul(G, x)
P = point_mul(G, d0)
return bytes_from_point(P)

def schnorr_sign(msg, seckey0, aux_rand):
def schnorr_sign(msg, seckey, aux_rand):
if len(msg) != 32:
raise ValueError('The message must be a 32-byte array.')
seckey0 = int_from_bytes(seckey0)
if not (1 <= seckey0 <= n - 1):
d0 = int_from_bytes(seckey)
if not (1 <= d0 <= n - 1):
raise ValueError('The secret key must be an integer in the range 1..n-1.')
if len(aux_rand) != 32:
raise ValueError('aux_rand must be 32 bytes instead of %i.' % len(aux_rand))
P = point_mul(G, seckey0)
seckey = seckey0 if has_even_y(P) else n - seckey0
t = xor_bytes(bytes_from_int(seckey), tagged_hash("BIP340/aux", aux_rand))
P = point_mul(G, d0)
d = d0 if has_even_y(P) else n - d0
t = xor_bytes(bytes_from_int(d), tagged_hash("BIP340/aux", aux_rand))
k0 = int_from_bytes(tagged_hash("BIP340/nonce", t + bytes_from_point(P) + msg)) % n
if k0 == 0:
raise RuntimeError('Failure. This happens only with negligible probability.')
R = point_mul(G, k0)
k = n - k0 if not has_square_y(R) else k0
e = int_from_bytes(tagged_hash("BIP340/challenge", bytes_from_point(R) + bytes_from_point(P) + msg)) % n
sig = bytes_from_point(R) + bytes_from_int((k + e * seckey) % n)
sig = bytes_from_point(R) + bytes_from_int((k + e * d) % n)
debug_print_vars()
if not schnorr_verify(msg, bytes_from_point(P), sig):
raise RuntimeError('The signature does not pass verification.')
raise RuntimeError('The created signature does not pass verification.')
return sig

def schnorr_verify(msg, pubkey, sig):
Expand All @@ -123,26 +133,29 @@ def schnorr_verify(msg, pubkey, sig):
if len(sig) != 64:
raise ValueError('The signature must be a 64-byte array.')
P = lift_x_even_y(pubkey)
if (P is None):
return False
r = int_from_bytes(sig[0:32])
s = int_from_bytes(sig[32:64])
if (r >= p or s >= n):
if (P is None) or (r >= p) or (s >= n):
debug_print_vars()
return False
e = int_from_bytes(tagged_hash("BIP340/challenge", sig[0:32] + pubkey + msg)) % n
R = point_add(point_mul(G, s), point_mul(P, n - e))
if R is None or not has_square_y(R) or x(R) != r:
if (R is None) or (not has_square_y(R)) or (x(R) != r):
debug_print_vars()
return False
debug_print_vars()
return True

#
# The following code is only used to verify the test vectors.
#
import csv
import os
import sys

def test_vectors():
all_passed = True
with open('test-vectors.csv', newline='') as csvfile:
with open(os.path.join(sys.path[0], 'test-vectors.csv'), newline='') as csvfile:
reader = csv.reader(csvfile)
reader.__next__()
for row in reader:
Expand All @@ -151,7 +164,7 @@ def test_vectors():
msg = bytes.fromhex(msg)
sig = bytes.fromhex(sig)
result = result == 'TRUE'
print('\nTest vector #%-3i: ' % int(index))
print('\nTest vector', ('#' + index).rjust(3, ' ') + ':')
if seckey != '':
seckey = bytes.fromhex(seckey)
pubkey_actual = pubkey_gen(seckey)
Expand All @@ -160,13 +173,17 @@ def test_vectors():
print(' Expected key:', pubkey.hex().upper())
print(' Actual key:', pubkey_actual.hex().upper())
aux_rand = bytes.fromhex(aux_rand)
sig_actual = schnorr_sign(msg, seckey, aux_rand)
if sig == sig_actual:
print(' * Passed signing test.')
else:
print(' * Failed signing test.')
print(' Expected signature:', sig.hex().upper())
print(' Actual signature:', sig_actual.hex().upper())
try:
sig_actual = schnorr_sign(msg, seckey, aux_rand)
if sig == sig_actual:
print(' * Passed signing test.')
else:
print(' * Failed signing test.')
print(' Expected signature:', sig.hex().upper())
print(' Actual signature:', sig_actual.hex().upper())
all_passed = False
except RuntimeError as e:
print(' * Signing test raised exception:', e)
all_passed = False
result_actual = schnorr_verify(msg, pubkey, sig)
if result == result_actual:
Expand All @@ -185,5 +202,26 @@ def test_vectors():
print('Some test vectors failed.')
return all_passed

#
# The following code is only used for debugging
#
import inspect

def pretty(v):
if isinstance(v, bytes):
return '0x' + v.hex()
if isinstance(v, int):
return pretty(bytes_from_int(v))
if isinstance(v, tuple):
return tuple(map(pretty, v))
return v

def debug_print_vars():
if DEBUG:
frame = inspect.currentframe().f_back
print(' Variables in function ', frame.f_code.co_name, ' at line ', frame.f_lineno, ':', sep='')
for var_name, var_val in frame.f_locals.items():
print(' ' + var_name.rjust(11, ' '), '==', pretty(var_val))

if __name__ == '__main__':
test_vectors()
2 changes: 1 addition & 1 deletion bip-0340/test-vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def vector0():
P = point_mul(G, x)
assert(y(P) % 2 == 0)

# For historic reasons (pubkey tiebreaker was squareness and not evenness)
# For historical reasons (pubkey tiebreaker was squareness and not evenness)
# we should have at least one test vector where the the point reconstructed
# from the public key has a square and one where it has a non-square Y
# coordinate. In this one Y is non-square.
Expand Down