From 67a4b2c0e0fe0bcee52089e72849ffb12926fc1d Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Thu, 25 Jul 2024 02:48:51 +0200 Subject: [PATCH] feat(special primes accel): Support Crandall primes / Pseudo-Mersenne Prime fast reduction - closes #11 --- benchmarks/bench_fp.nim | 22 +- constantine.nimble | 6 +- constantine/math/arithmetic/finite_fields.nim | 146 ++++++++---- .../math/arithmetic/limbs_crandall.nim | 221 ++++++++++++++++++ .../named/config_fields_and_curves.nim | 3 + constantine/named/deriv/parser_fields.nim | 27 ++- constantine/named/properties_fields.nim | 8 + tests/math_fields/t_finite_fields.nim | 71 +++--- .../math_fields/t_finite_fields_mulsquare.nim | 2 +- tests/math_fields/t_finite_fields_vs_gmp.nim | 2 +- 10 files changed, 413 insertions(+), 95 deletions(-) create mode 100644 constantine/math/arithmetic/limbs_crandall.nim diff --git a/benchmarks/bench_fp.nim b/benchmarks/bench_fp.nim index 7678340b..c4c8a069 100644 --- a/benchmarks/bench_fp.nim +++ b/benchmarks/bench_fp.nim @@ -55,22 +55,24 @@ proc main() = sqr2xUnrBench(Fp[curve], Iters) rdc2xBench(Fp[curve], Iters) smallSeparator() - sumprodBench(Fp[curve], Iters) - smallSeparator() + when not Fp[curve].isCrandallPrimeField(): + sumprodBench(Fp[curve], Iters) + smallSeparator() toBigBench(Fp[curve], Iters) toFieldBench(Fp[curve], Iters) smallSeparator() invBench(Fp[curve], ExponentIters) invVartimeBench(Fp[curve], ExponentIters) isSquareBench(Fp[curve], ExponentIters) - sqrtBench(Fp[curve], ExponentIters) - sqrtRatioBench(Fp[curve], ExponentIters) - when curve == Bandersnatch: - sqrtVartimeBench(Fp[curve], ExponentIters) - sqrtRatioVartimeBench(Fp[curve], ExponentIters) - # Exponentiation by a "secret" of size ~the curve order - powBench(Fp[curve], ExponentIters) - powVartimeBench(Fp[curve], ExponentIters) + when not Fp[curve].isCrandallPrimeField(): # TODO implement + sqrtBench(Fp[curve], ExponentIters) + sqrtRatioBench(Fp[curve], ExponentIters) + when curve == Bandersnatch: + sqrtVartimeBench(Fp[curve], ExponentIters) + sqrtRatioVartimeBench(Fp[curve], ExponentIters) + # Exponentiation by a "secret" of size ~the curve order + powBench(Fp[curve], ExponentIters) + powVartimeBench(Fp[curve], ExponentIters) separator() main() diff --git a/constantine.nimble b/constantine.nimble index 5ffcabb9..f495507e 100644 --- a/constantine.nimble +++ b/constantine.nimble @@ -402,10 +402,10 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[ ("tests/math_fields/t_io_fields", false), # ("tests/math_fields/t_finite_fields.nim", false), # ("tests/math_fields/t_finite_fields_conditional_arithmetic.nim", false), - # ("tests/math_fields/t_finite_fields_mulsquare.nim", false), + ("tests/math_fields/t_finite_fields_mulsquare.nim", false), # ("tests/math_fields/t_finite_fields_sqrt.nim", false), - ("tests/math_fields/t_finite_fields_powinv.nim", false), - # ("tests/math_fields/t_finite_fields_vs_gmp.nim", true), + # ("tests/math_fields/t_finite_fields_powinv.nim", false), + ("tests/math_fields/t_finite_fields_vs_gmp.nim", true), # ("tests/math_fields/t_fp_cubic_root.nim", false), # Double-precision finite fields diff --git a/constantine/math/arithmetic/finite_fields.nim b/constantine/math/arithmetic/finite_fields.nim index 692baf0d..7e2effa7 100644 --- a/constantine/math/arithmetic/finite_fields.nim +++ b/constantine/math/arithmetic/finite_fields.nim @@ -30,7 +30,8 @@ import constantine/platforms/abstractions, constantine/serialization/endians, constantine/named/properties_fields, - ./bigints, ./bigints_montgomery + ./bigints, ./bigints_montgomery, + ./limbs_crandall, ./limbs_extmul when UseASM_X86_64: import ./assembly/limbs_asm_modular_x86 @@ -54,10 +55,13 @@ export Fp, Fr, FF func fromBig*(dst: var FF, src: BigInt) = ## Convert a BigInt to its Montgomery form - when nimvm: - dst.mres.montyResidue_precompute(src, FF.getModulus(), FF.getR2modP(), FF.getNegInvModWord()) + when FF.isCrandallPrimeField(): + dst.mres = src else: - dst.mres.getMont(src, FF.getModulus(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) + when nimvm: + dst.mres.montyResidue_precompute(src, FF.getModulus(), FF.getR2modP(), FF.getNegInvModWord()) + else: + dst.mres.getMont(src, FF.getModulus(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func fromBig*[Name: static Algebra](T: type FF[Name], src: BigInt): FF[Name] {.noInit.} = ## Convert a BigInt to its Montgomery form @@ -65,7 +69,10 @@ func fromBig*[Name: static Algebra](T: type FF[Name], src: BigInt): FF[Name] {.n func fromField*(dst: var BigInt, src: FF) {.inline.} = ## Convert a finite-field element to a BigInt in natural representation - dst.fromMont(src.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits()) + when FF.isCrandallPrimeField(): + dst = src.mres + else: + dst.fromMont(src.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits()) func toBig*(src: FF): auto {.noInit, inline.} = ## Convert a finite-field element to a BigInt in natural representation @@ -121,11 +128,17 @@ func isZero*(a: FF): SecretBool = func isOne*(a: FF): SecretBool = ## Constant-time check if one - a.mres == FF.getMontyOne() + when FF.isCrandallPrimeField(): + a.mres.isOne() + else: + a.mres == FF.getMontyOne() func isMinusOne*(a: FF): SecretBool = ## Constant-time check if -1 (mod p) - a.mres == FF.getMontyPrimeMinus1() + when FF.isCrandallPrimeField: + {.error: "Not implemented".} + else: + a.mres == FF.getMontyPrimeMinus1() func isOdd*(a: FF): SecretBool {. error: "Do you need the actual value to be odd\n" & @@ -141,14 +154,20 @@ func setOne*(a: var FF) = # Note: we need 1 in Montgomery residue form # TODO: Nim codegen is not optimal it uses a temporary # Check if the compiler optimizes it away - a.mres = FF.getMontyOne() + when FF.isCrandallPrimeField(): + a.mres.setOne() + else: + a.mres = FF.getMontyOne() func setMinusOne*(a: var FF) = ## Set ``a`` to -1 (mod p) # Note: we need -1 in Montgomery residue form # TODO: Nim codegen is not optimal it uses a temporary # Check if the compiler optimizes it away - a.mres = FF.getMontyPrimeMinus1() + when FF.isCrandallPrimeField(): + {.error: "Not implemented".} + else: + a.mres = FF.getMontyPrimeMinus1() func neg*(r: var FF, a: FF) {.meter.} = ## Negate modulo p @@ -237,19 +256,36 @@ func double*(r: var FF, a: FF) {.meter.} = func prod*(r: var FF, a, b: FF, skipFinalSub: static bool = false) {.meter.} = ## Store the product of ``a`` by ``b`` modulo p into ``r`` ## ``r`` is initialized / overwritten - r.mres.mulMont(a.mres, b.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub) + when FF.isCrandallPrimeField(): + var r2 {.noInit.}: FF.Name.getLimbs2x() + r2.prod(a.mres.limbs, b.mres.limbs) + r.mres.limbs.reduce_crandall_partial(r2, FF.bits(), FF.getCrandallPrimeSubterm()) + when not skipFinalSub: + r.mres.limbs.reduce_crandall_final(FF.bits(), FF.getCrandallPrimeSubterm()) + else: + r.mres.mulMont(a.mres, b.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub) func square*(r: var FF, a: FF, skipFinalSub: static bool = false) {.meter.} = ## Squaring modulo p - r.mres.squareMont(a.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub) + when FF.isCrandallPrimeField(): + var r2 {.noInit.}: FF.Name.getLimbs2x() + r2.square(a.mres.limbs) + r.mres.limbs.reduce_crandall_partial(r2, FF.bits(), FF.getCrandallPrimeSubterm()) + when not skipFinalSub: + r.mres.limbs.reduce_crandall_final(FF.bits(), FF.getCrandallPrimeSubterm()) + else: + r.mres.squareMont(a.mres, FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub) func sumprod*[N: static int](r: var FF, a, b: array[N, FF], skipFinalSub: static bool = false) {.meter.} = ## Compute r <- ⅀aᵢ.bᵢ (mod M) (sum of products) # We rely on FF and Bigints having the same repr to avoid array copies - r.mres.sumprodMont( - cast[ptr array[N, typeof(a[0].mres)]](a.unsafeAddr)[], - cast[ptr array[N, typeof(b[0].mres)]](b.unsafeAddr)[], - FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub) + when FF.isCrandallPrimeField(): + {.error: "Not implemented".} + else: + r.mres.sumprodMont( + cast[ptr array[N, typeof(a[0].mres)]](a.unsafeAddr)[], + cast[ptr array[N, typeof(b[0].mres)]](b.unsafeAddr)[], + FF.getModulus(), FF.getNegInvModWord(), FF.getSpareBits(), skipFinalSub) # ############################################################ # @@ -329,7 +365,10 @@ func inv*(r: var FF, a: FF) = ## Incidentally this avoids extra check ## to convert Jacobian and Projective coordinates ## to affine for elliptic curve - r.mres.invmod(a.mres, FF.getR2modP(), FF.getModulus()) + when FF.isCrandallPrimeField(): + r.mres.invmod(a.mres, FF.getModulus()) + else: + r.mres.invmod(a.mres, FF.getR2modP(), FF.getModulus()) func inv*(a: var FF) = ## Inversion modulo p @@ -347,7 +386,10 @@ func inv_vartime*(r: var FF, a: FF) {.tags: [VarTime].} = ## Incidentally this avoids extra check ## to convert Jacobian and Projective coordinates ## to affine for elliptic curve - r.mres.invmod_vartime(a.mres, FF.getR2modP(), FF.getModulus()) + when FF.isCrandallPrimeField(): + r.mres.invmod_vartime(a.mres, FF.getModulus()) + else: + r.mres.invmod_vartime(a.mres, FF.getR2modP(), FF.getModulus()) func inv_vartime*(a: var FF) {.tags: [VarTime].} = ## Variable-time Inversion modulo p @@ -509,25 +551,31 @@ func pow*(a: var FF, exponent: BigInt) = ## Exponentiation modulo p ## ``a``: a field element to be exponentiated ## ``exponent``: a big integer - const windowSize = 5 # TODO: find best window size for each curves - a.mres.powMont( - exponent, - FF.getModulus(), FF.getMontyOne(), - FF.getNegInvModWord(), windowSize, - FF.getSpareBits() - ) + when FF.isCrandallPrimeField(): + {.error: "Not implemented".} + else: + const windowSize = 5 # TODO: find best window size for each curves + a.mres.powMont( + exponent, + FF.getModulus(), FF.getMontyOne(), + FF.getNegInvModWord(), windowSize, + FF.getSpareBits() + ) func pow*(a: var FF, exponent: openarray[byte]) = ## Exponentiation modulo p ## ``a``: a field element to be exponentiated ## ``exponent``: a big integer in canonical big endian representation - const windowSize = 5 # TODO: find best window size for each curves - a.mres.powMont( - exponent, - FF.getModulus(), FF.getMontyOne(), - FF.getNegInvModWord(), windowSize, - FF.getSpareBits() - ) + when FF.isCrandallPrimeField(): + {.error: "Not implemented".} + else: + const windowSize = 5 # TODO: find best window size for each curves + a.mres.powMont( + exponent, + FF.getModulus(), FF.getMontyOne(), + FF.getNegInvModWord(), windowSize, + FF.getSpareBits() + ) func pow*(a: var FF, exponent: FF) = ## Exponentiation modulo p @@ -557,13 +605,17 @@ func pow_vartime*(a: var FF, exponent: BigInt) = ## - memory access analysis ## - power analysis ## - timing analysis - const windowSize = 5 # TODO: find best window size for each curves - a.mres.powMont_vartime( - exponent, - FF.getModulus(), FF.getMontyOne(), - FF.getNegInvModWord(), windowSize, - FF.getSpareBits() - ) + + when FF.isCrandallPrimeField(): + {.error: "Not implemented".} + else: + const windowSize = 5 # TODO: find best window size for each curves + a.mres.powMont_vartime( + exponent, + FF.getModulus(), FF.getMontyOne(), + FF.getNegInvModWord(), windowSize, + FF.getSpareBits() + ) func pow_vartime*(a: var FF, exponent: openarray[byte]) = ## Exponentiation modulo p @@ -576,13 +628,17 @@ func pow_vartime*(a: var FF, exponent: openarray[byte]) = ## - memory access analysis ## - power analysis ## - timing analysis - const windowSize = 5 # TODO: find best window size for each curves - a.mres.powMont_vartime( - exponent, - FF.getModulus(), FF.getMontyOne(), - FF.getNegInvModWord(), windowSize, - FF.getSpareBits() - ) + + when FF.isCrandallPrimeField(): + {.error: "Not implemented".} + else: + const windowSize = 5 # TODO: find best window size for each curves + a.mres.powMont_vartime( + exponent, + FF.getModulus(), FF.getMontyOne(), + FF.getNegInvModWord(), windowSize, + FF.getSpareBits() + ) func pow_vartime*(a: var FF, exponent: FF) = ## Exponentiation modulo p diff --git a/constantine/math/arithmetic/limbs_crandall.nim b/constantine/math/arithmetic/limbs_crandall.nim new file mode 100644 index 00000000..f88bb729 --- /dev/null +++ b/constantine/math/arithmetic/limbs_crandall.nim @@ -0,0 +1,221 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + constantine/platforms/abstractions, + ./limbs + +# No exceptions allowed +{.push raises: [], checks: off.} + +# ############################################################ +# +# Multiprecision Crandall prime / +# Pseudo-Mersenne Prime Arithmetic +# +# ############################################################ +# +# Crandall primes have the form p = 2ᵐ-c +# We use special lazily reduced arithmetic +# where reduction is only done when we overflow 2ʷⁿ +# with w the word bitwidth and n the number of words +# to represent p. +# For example for Curve25519, p = 2²⁵⁵-19 and 2ʷⁿ=2²⁵⁶ +# Hence reduction will only happen when overflowing 2²⁵⁶ bits + +# Fast reduction +# ------------------------------------------------------------ + +func reduce_crandall_partial_impl[N: static int]( + r: var Limbs[N], + a: Limbs[2*N], + bits: static int, + c: static SecretWord) = + ## Partial Reduction modulo p + ## with p with special form 2ᵐ-c + ## called "Crandall prime" or Pseudo-Mersenne Prime in the litterature + ## + ## This is a partial reduction that reduces down to + ## 2ᵐ, i.e. it fits in the same amount of word by p + ## but values my be up to p+c + ## + ## Crandal primes allow fast reduction from the fact that + ## 2ᵐ-c ≡ 0 (mod p) + ## <=> 2ᵐ ≡ c (mod p) + ## <=> a2ᵐ+b ≡ ac + b (mod p) + + # In our case we split at 2ʷⁿ with w the word size (32 or 64-bit) + # and N the number of words needed to represent the prime + # hence 2ʷⁿ ≡ 2ʷⁿ⁻ᵐc (mod p), we call this cs (c shifted) + # so a2ʷⁿ+b ≡ a2ʷⁿ⁻ᵐc + b (mod p) + # + # With concrete instantiations: + # for p = 2²⁵⁵-19 (Curve25519) + # 2²⁵⁵ ≡ 19 (mod p) + # 2²⁵⁶ ≡ 2*19 (mod p) + # We rewrite the 510 bits multiplication result as + # a2²⁵⁶+b = a*2*19 + b (mod p) + # + # For Bitcoin/Ethereum, p = 2²⁵⁶-0x1000003D1 = + # p = 2²⁵⁶ - (2³²+2⁹+2⁸+2⁷+2⁶+2⁴+1) + # 2²⁵⁶ ≡ 0x1000003D1 (mod p) + # We rewrite the 512 bits multiplication result as + # a2²⁵⁶+b = a*0x1000003D1 + b (mod p) + # + # Note: on a w-bit architecture, c MUST be less than w-bit + # This is not the case for secp256k1 on 32-bit + # as it's c = 2³²+2⁹+2⁸+2⁷+2⁶+2⁴+1 + # Though as multiplying by 2³² is free + # we can special-case the problem, if there was a + # 32-bit platform with add-with-carry that is still a valuable target. + # (otherwise unsaturated arithmetic is superior) + + const S = (N*WordBitWidth - bits) + const cs = c shl S + static: doAssert 0 <= S and S < WordBitWidth + + var hi: SecretWord + + # First reduction pass + # multiply high-words by c shifted and accumulate in low-words + # assumes cs fits in a single word. + + # (hi, r₀) <- aₙ*cs + a₀ + muladd1(hi, r[0], a[N], cs, a[0]) + staticFor i, 1, N: + # (hi, rᵢ) <- aᵢ₊ₙ*cs + aᵢ + hi + muladd2(hi, r[i], a[i+N], cs, a[i], hi) + + # The first reduction pass may carry in `hi` + # which would be hi*2ʷⁿ ≡ hi*2ʷⁿ⁻ᵐ*c (mod p) + # ≡ hi*cs (mod p) + + # Move all extra bits to hi, i.e. double-word shift + hi = (hi shl S) or (r[N-1] shr (WordBitWidth-S)) + + # High-bit has been "carried" to `hi`, cancel it. + # Note: there might be up to `c` not reduced. + r[N-1] = r[N-1] and (MaxWord shr S) + + # hi*cs (mod p), + # hi has already been shifted so we use `c` instead of `cs` + when N*WordBitWidth == bits: # Secp256k1 only according to eprint/iacr 2018/985 + var t0, t1: SecretWord + mul(t1, t0, hi, c) + + # Second pass + var carry: Carry + addC(carry, r[0], r[0], t0, Carry(0)) + addC(carry, r[1], r[1], t1, carry) + staticFor i, 2, N: + addC(carry, r[i], r[i], Zero, carry) + + # Third pass + mul(t1, t0, SecretWord(carry), c) + addC(carry, r[0], r[0], t0, Carry(0)) + addC(carry, r[1], r[1], t1, carry) + + else: + hi *= c # Cannot overflow + + # Second pass + var carry: Carry + addC(carry, r[0], r[0], hi, Carry(0)) + staticFor i, 1, N: + addC(carry, r[i], r[i], Zero, carry) + +func reduce_crandall_final_impl[N: static int]( + a: var Limbs[N], + bits: int, + c: SecretWord) = + ## Final Reduction modulo p + ## with p with special form 2ᵐ-c + ## called "Crandall prime" or Pseudo-Mersenne Prime in the litterature + ## + ## This reduces `a` from [0, 2ᵐ) to [0, 2ᵐ-c) + let S = (N*WordBitWidth - bits) + let top = MaxWord shr S + debug: doAssert 0 <= S and S < WordBitWidth + + # 1. Substract p = 2ᵐ-c + # p is in the form 0x7FFF...FFFF`c` (7FFF or 3FFF or ... depending of 255-bit 254-bit ...) + var t {.noInit.}: Limbs[N] + var borrow: Borrow + subB(borrow, t[0], a[0], -c, Borrow(0)) + for i in 1 ..< N-1: + subB(borrow, t[i], a[i], MaxWord, borrow) + when N >= 2: + subB(borrow, t[N-1], a[N-1], top, borrow) + + # 2. If underflow, a has the proper reduced result + # otherwise t has the proper reduced result + a.ccopy(t, not SecretBool(borrow)) + +func reduce_crandall_partial*[N: static int]( + r: var Limbs[N], + a: Limbs[2*N], + bits: static int, + c: static SecretWord) = + ## Partial Reduction modulo p + ## with p with special form 2ᵐ-c + ## called "Crandall prime" or Pseudo-Mersenne Prime in the litterature + ## + ## This is a partial reduction that reduces down to + ## 2ᵐ, i.e. it fits in the same amount of word by p + ## but values my be up to p+c + ## + ## Crandal primes allow fast reduction from the fact that + ## 2ᵐ-c ≡ 0 (mod p) + ## <=> 2ᵐ ≡ c (mod p) + ## <=> a2ᵐ+b ≡ ac + b (mod p) + + static: doAssert N*WordBitWidth >= bits + reduce_crandall_partial_impl(r, a, bits, c) + +func reduce_crandall_final*[N: static int]( + a: var Limbs[N], + bits: static int, + c: static SecretWord) = + ## Final Reduction modulo p + ## with p with special form 2ᵐ-c + ## called "Crandall prime" or Pseudo-Mersenne Prime in the litterature + ## + ## This reduces `a` from [0, 2ᵐ) to [0, 2ᵐ-c) + + static: doAssert N*WordBitWidth >= bits + reduce_crandall_final_impl(a, bits, c) + +# lazily reduced arithmetic +# ------------------------------------------------------------ + + +func sum_crandall_impl[N: static int]( + r: var Limbs[N], a, b: Limbs[N], + bits: int, + c: SecretWord) = + + let S = (N*WordBitWidth - bits) + let cs = c shl S + debug: doAssert 0 <= S and S < WordBitWidth + + let overflow1 = r.sum(a, b) + # If there is an overflow, substract 2ˢp = 2ʷⁿ - 2ˢc + # with w the word bitwidth and n the number of words + # to represent p. + # For example for Curve25519, p = 2²⁵⁵-19 and 2ʷⁿ=2²⁵⁶ + + # 0x0000 if no overflow or 0xFFFF if overflow + let mask1 = -SecretWord(overflow1) + + # 2ˢp = 2ʷⁿ - 2ˢc ≡ 2ˢc (mod p) + let overflow2 = r.add(mask1 and cs) + let mask2 = -SecretWord(overflow2) + + # We may carry again, but we just did -2ˢc + # so adding back 2ˢc for the extra 2ʷⁿ bit cannot carry + r[0] += mask2 and cs diff --git a/constantine/named/config_fields_and_curves.nim b/constantine/named/config_fields_and_curves.nim index c3fc397a..330d4bf5 100644 --- a/constantine/named/config_fields_and_curves.nim +++ b/constantine/named/config_fields_and_curves.nim @@ -183,6 +183,7 @@ declareCurves: curve Edwards25519: # Bernstein curve bitwidth: 255 modulus: "0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed" + modulusKind: Crandall(19) # Montgomery form: y² = x³ + 486662x² + x # Edwards form: x² + y² = 1+dx²y² with d = 121665/121666 @@ -221,6 +222,8 @@ declareCurves: curve Secp256k1: # Bitcoin curve bitwidth: 256 modulus: "0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f" + modulusKind: Crandall(0x1000003D1'u64) + order: "0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141" orderBitwidth: 256 eq_form: ShortWeierstrass diff --git a/constantine/named/deriv/parser_fields.nim b/constantine/named/deriv/parser_fields.nim index e21be82a..bf6c6ce3 100644 --- a/constantine/named/deriv/parser_fields.nim +++ b/constantine/named/deriv/parser_fields.nim @@ -24,6 +24,10 @@ import # for example when using the `r2modP` constant in multiple overloads in the same module type + PrimeKind* = enum + kGeneric + kCrandall # Crandall Prime are in the form 2ᵐ-c (also called pseudo-Mersenne primes) + CurveFamily* = enum NoFamily BarretoNaehrig # BN curve @@ -95,6 +99,8 @@ type # Field parameters bitWidth*: NimNode # nnkIntLit modulus*: NimNode # nnkStrLit (hex) + modulusKind*: PrimeKind + modulusKindAssociatedValue*: BiggestInt # Towering nonresidue_fp*: NimNode # nnkIntLit @@ -170,13 +176,20 @@ proc parseCurveDecls*(defs: var seq[CurveParams], curves: NimNode) = curveParams[i][1].expectKind(nnkStmtList) let sectionVal = curveParams[i][1][0] + # Field if sectionId.eqIdent"bitwidth": params.bitWidth = sectionVal elif sectionId.eqident"modulus": params.modulus = sectionVal + elif sectionId.eqIdent"modulusKind": + sectionVal.expectKind(nnkCall) + sectionVal[0].expectIdent"Crandall" + params.modulusKind = kCrandall + params.modulusKindAssociatedValue = sectionVal[1].intVal + + # Curve elif sectionId.eqIdent"family": params.family = parseEnum[CurveFamily]($sectionVal) - elif sectionId.eqIdent"eq_form": params.eq_form = parseEnum[CurveEquationForm]($sectionVal) elif sectionId.eqIdent"coef_a": @@ -217,6 +230,7 @@ proc parseCurveDecls*(defs: var seq[CurveParams], curves: NimNode) = elif sectionId.eqIdent"nonresidue_fp2": params.nonresidue_fp2 = sectionVal + # Pairings elif sectionId.eqIdent"embedding_degree": params.embedding_degree = sectionVal.intVal.int elif sectionId.eqIdent"sexticTwist": @@ -242,6 +256,7 @@ proc genFieldsConstants(defs: seq[CurveParams]): NimNode = var MapCurveBitWidth = nnkBracket.newTree() var MapCurveOrderBitWidth = nnkBracket.newTree() var curveModStmts = newStmtList() + var crandallStmts = newStmtList() for curveDef in defs: @@ -271,6 +286,15 @@ proc genFieldsConstants(defs: seq[CurveParams]): NimNode = ) ) + crandallStmts.add newConstStmt( + exported($curve & "_fp_isCrandall"), + newLit(curveDef.modulusKind == kCrandall) + ) + if curveDef.modulusKind == kCrandall: + crandallStmts.add newConstStmt( + exported($curve & "_fp_CrandallSubTerm"), + newCall(bindsym"uint64", newLit(curveDef.modulusKindAssociatedValue)) + ) # Field Fr if not curveDef.order.isNil: curveDef.orderBitwidth.expectKind(nnkIntLit) @@ -316,6 +340,7 @@ proc genFieldsConstants(defs: seq[CurveParams]): NimNode = exported("CurveBitWidth"), MapCurveBitWidth ) result.add curveModStmts + result.add crandallStmts # const CurveOrderBitSize: array[Curve, int] = ... result.add newConstStmt( exported("CurveOrderBitWidth"), MapCurveOrderBitWidth diff --git a/constantine/named/properties_fields.nim b/constantine/named/properties_fields.nim index abc07b73..21d2e84f 100644 --- a/constantine/named/properties_fields.nim +++ b/constantine/named/properties_fields.nim @@ -97,6 +97,14 @@ template getBigInt*[Name: static Algebra](T: type FF[Name]): untyped = ## Get the underlying BigInt type. typeof(default(T).mres) +template isCrandallPrimeField*(F: type Fr): static bool = false + +macro isCrandallPrimeField*[Name: static Algebra](F: type Fp[Name]): static bool = + result = bindSym($Name & "_fp_isCrandall") + +macro getCrandallPrimeSubterm*[Name: static Algebra](F: type Fp[Name]): static SecretWord = + result = newcall(bindSym"SecretWord", bindSym($Name & "_fp_CrandallSubTerm")) + func bits*[Name: static Algebra](T: type FF[Name]): static int = T.getBigInt().bits diff --git a/tests/math_fields/t_finite_fields.nim b/tests/math_fields/t_finite_fields.nim index 021c045c..79afa4b9 100644 --- a/tests/math_fields/t_finite_fields.nim +++ b/tests/math_fields/t_finite_fields.nim @@ -319,52 +319,55 @@ proc largeField() = check: bool r.isZero() - test "fromMont doesn't need a final substraction with 256-bit prime (full word used)": - block: - let a = Fp[Secp256k1].getMinusOne() - let expected = BigInt[256].fromHex"0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2E" + # Outdated tests as Crandall primes / Pseudo-Mersenne primes + # don't use Montgomery representaiton anymore - var r: BigInt[256] - r.fromField(a) + # test "fromMont doesn't need a final substraction with 256-bit prime (full word used)": + # block: + # let a = Fp[Secp256k1].getMinusOne() + # let expected = BigInt[256].fromHex"0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2E" - check: bool(r == expected) - block: - var a: Fp[Secp256k1] - var d: FpDbl[Secp256k1] + # var r: BigInt[256] + # r.fromField(a) - # Set Montgomery repr to the largest field element in Montgomery Residue form - a.mres = BigInt[256].fromHex"0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2E" - d.limbs2x = (BigInt[512].fromHex"0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2E").limbs + # check: bool(r == expected) + # block: + # var a: Fp[Secp256k1] + # var d: FpDbl[Secp256k1] - var r, expected: BigInt[256] + # # Set Montgomery repr to the largest field element in Montgomery Residue form + # a.mres = BigInt[256].fromHex"0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2E" + # d.limbs2x = (BigInt[512].fromHex"0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2E").limbs - r.fromField(a) - expected.limbs.redc2xMont(d.limbs2x, Fp[Secp256k1].getModulus().limbs, Fp[Secp256k1].getNegInvModWord(), Fp[Secp256k1].getSpareBits()) + # var r, expected: BigInt[256] - check: bool(r == expected) + # r.fromField(a) + # expected.limbs.redc2xMont(d.limbs2x, Fp[Secp256k1].getModulus().limbs, Fp[Secp256k1].getNegInvModWord(), Fp[Secp256k1].getSpareBits()) - test "fromMont doesn't need a final substraction with 255-bit prime (1 spare bit)": - block: - let a = Fp[Edwards25519].getMinusOne() - let expected = BigInt[255].fromHex"0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffec" + # check: bool(r == expected) - var r: BigInt[255] - r.fromField(a) + # test "fromMont doesn't need a final substraction with 255-bit prime (1 spare bit)": + # block: + # let a = Fp[Edwards25519].getMinusOne() + # let expected = BigInt[255].fromHex"0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffec" - check: bool(r == expected) - block: - var a: Fp[Edwards25519] - var d: FpDbl[Edwards25519] + # var r: BigInt[255] + # r.fromField(a) + + # check: bool(r == expected) + # block: + # var a: Fp[Edwards25519] + # var d: FpDbl[Edwards25519] - # Set Montgomery repr to the largest field element in Montgomery Residue form - a.mres = BigInt[255].fromHex"0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffec" - d.limbs2x = (BigInt[512].fromHex"0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffec").limbs + # # Set Montgomery repr to the largest field element in Montgomery Residue form + # a.mres = BigInt[255].fromHex"0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffec" + # d.limbs2x = (BigInt[512].fromHex"0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffec").limbs - var r, expected: BigInt[255] + # var r, expected: BigInt[255] - r.fromField(a) - expected.limbs.redc2xMont(d.limbs2x, Fp[Edwards25519].getModulus().limbs, Fp[Edwards25519].getNegInvModWord(), Fp[Edwards25519].getSpareBits()) + # r.fromField(a) + # expected.limbs.redc2xMont(d.limbs2x, Fp[Edwards25519].getModulus().limbs, Fp[Edwards25519].getNegInvModWord(), Fp[Edwards25519].getSpareBits()) - check: bool(r == expected) + # check: bool(r == expected) largeField() diff --git a/tests/math_fields/t_finite_fields_mulsquare.nim b/tests/math_fields/t_finite_fields_mulsquare.nim index 668ace72..c454e8c6 100644 --- a/tests/math_fields/t_finite_fields_mulsquare.nim +++ b/tests/math_fields/t_finite_fields_mulsquare.nim @@ -28,7 +28,7 @@ echo "test_finite_fields_mulsquare xoshiro512** seed: ", seed static: doAssert defined(CTT_TEST_CURVES), "This modules requires the -d:CTT_TEST_CURVES compile option" proc sanity(Name: static Algebra) = - test "Squaring 0,1,2 with " & $Name & " [FastSquaring = " & $(Fp[Name].getSpareBits() >= 2) & "]": + test "Squaring 0,1,2 with " & $Algebra(C) & " [FastSquaring = " & $(Fp[Name].getSpareBits() >= 2) & "]": block: # 0² mod var n: Fp[Name] diff --git a/tests/math_fields/t_finite_fields_vs_gmp.nim b/tests/math_fields/t_finite_fields_vs_gmp.nim index 2da68470..218b727c 100644 --- a/tests/math_fields/t_finite_fields_vs_gmp.nim +++ b/tests/math_fields/t_finite_fields_vs_gmp.nim @@ -58,7 +58,7 @@ proc binary_prologue[Name: static Algebra, N: static int]( bTest = rng.random_unsafe(Fp[Name]) # Set modulus to curve modulus - let err = mpz_set_str(p, Fp[Name].getmodulus().toHex(), 0) + let err = mpz_set_str(p, Fp[Name].getModulus().toHex(), 0) doAssert err == 0, "Error on prime for curve " & $Name #########################################################