Skip to content

Commit

Permalink
feat(LLVM): add codegenerator for saturated field add/sub
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Aug 5, 2024
1 parent 1e34ec2 commit 432a91e
Show file tree
Hide file tree
Showing 15 changed files with 513 additions and 378 deletions.
8 changes: 8 additions & 0 deletions PLANNING.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ Other tracks are stretch goals, contributions towards them are accepted.
- introduce batchAffine_vartime
- Optimized square_repeated in assembly for Montgomery and Crandall/Pseudo-Mersenne primes
- Optimized elliptic curve directly calling assembly without ADX checks and limited input/output movement in registers or using function multi-versioning.
- LLVM IR:
- use internal or private linkage type
- look into calling conventions like "fast" or "Tail fast"
- check if returning a value from a function is properly optimized
compared to an in-place result parameter
- use the `readnone` (pure) and `readonly` memory attributes for functions
- look into passing parameter as arrays instead of pointers?
- use hot function attribute

### User Experience track

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ proc finalSubMayOverflowImpl*(
ctx.mov scratch[i], a[i]
ctx.sbb scratch[i], M[i]

# If it overflows here, it means that it was
# If it underflows here, it means that it was
# smaller than the modulus and we don't need `scratch`
ctx.sbb scratchReg, 0

Expand Down
4 changes: 2 additions & 2 deletions constantine/math_compiler/codegen_nvidia.nim
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import
constantine/platforms/abis/nvidia_abi {.all.},
constantine/platforms/abis/c_abi,
constantine/platforms/llvm/[llvm, nvidia_inlineasm],
constantine/platforms/llvm/llvm,
constantine/platforms/primitives,
./ir

export
nvidia_abi, nvidia_inlineasm,
nvidia_abi,
Flag, flag, wrapOpenArrayLenType

# ############################################################
Expand Down
13 changes: 8 additions & 5 deletions constantine/math_compiler/impl_fields_nvidia.nim
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
../platforms/llvm/llvm,
./ir, ./codegen_nvidia
constantine/platforms/llvm/[llvm, asm_nvidia],
./ir

# ############################################################
#
Expand Down Expand Up @@ -40,8 +40,11 @@ import
# but the carry codegen of madc.hi.cc.u64 has off-by-one
# - https://forums.developer.nvidia.com/t/incorrect-result-of-ptx-code/221067
# - old 32-bit bug: https://forums.developer.nvidia.com/t/wrong-result-returned-by-madc-hi-u64-ptx-instruction-for-specific-operands/196094
#
# See instruction throughput
# - https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions

proc finalSubMayOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) =
proc finalSubMayOverflow(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
Expand Down Expand Up @@ -74,7 +77,7 @@ proc finalSubMayOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field,
for i in 0 ..< N:
r[i] = bld.slct(scratch[i], a[i], underflowedModulus)

proc finalSubNoOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) =
proc finalSubNoOverflow(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
Expand Down Expand Up @@ -354,4 +357,4 @@ proc field_mul_CIOS_sparebit_gen(asy: Assembler_LLVM, cm: CurveMetadata, field:
proc field_mul_gen*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, skipFinalSub = false): FnDef =
  ## Generate an optimized modular multiplication kernel
  ## with parameters `a, b, modulus: Limbs -> Limbs`
  ##
  ## Delegates to the CIOS (Coarsely Integrated Operand Scanning) spare-bit
  ## Montgomery multiplication generator.
  ## `skipFinalSub` presumably leaves out the final conditional subtraction,
  ## leaving the result unreduced — confirm in field_mul_CIOS_sparebit_gen.
  return asy.field_mul_CIOS_sparebit_gen(cm, field, skipFinalSub)
166 changes: 166 additions & 0 deletions constantine/math_compiler/impl_fields_sat.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
constantine/platforms/llvm/[llvm, super_instructions],
./ir, ./codegen_nvidia

# ############################################################
#
# Field arithmetic with saturated limbs
#
# ############################################################
#
# This implements field operations in pure LLVM
# using saturated limbs, i.e. 64-bit words on 64-bit platforms.
#
# This relies on hardware addition-with-carry and subtraction-with-borrow
# for efficiency.
#
# As such it is not suitable for platforms with no carry flags such as:
# - WASM
# - MIPS
# - RISC-V
# - Metal
#
# It may be suitable for Intel GPUs as the virtual ISA does support add-carry
#
# It is suitable for:
# - ARM
# - AMD GPUs (for prototyping)
#
# The following backends have better optimizations through assembly:
# - x86: access to ADOX and ADCX interleaved double-carry chain
# - Nvidia: access to multiply accumulate instruction
# and non-interleaved double-carry chain
#
# AMD GPUs may benefit from using 24-bit limbs
# - https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/programmer-references/AMD_OpenCL_Programming_Optimization_Guide2.pdf
# p2-23:
# Generally, the throughput and latency for 32-bit integer operations is the same
# as for single-precision floating point operations.
# 24-bit integer MULs and MADs have four times the throughput of 32-bit integer
# multiplies. 24-bit signed and unsigned integers are natively supported on the
# GCN family of devices. The use of OpenCL built-in functions for mul24 and mad24
# is encouraged. Note that mul24 can be useful for array indexing operations
# Doc from 2015, it might not apply to RDNA family
# - https://free.eol.cn/edu_net/edudown/AMDppt/OpenCL%20Programming%20and%20Optimization%20-%20Part%20I.pdf
# slide 24
#
# - https://chipsandcheese.com/2023/01/07/microbenchmarking-amds-rdna-3-graphics-architecture/
# "Since Turing, Nvidia also achieves very good integer multiplication performance.
# Integer multiplication appears to be extremely rare in shader code,
# and AMD doesn’t seem to have optimized for it.
# 32-bit integer multiplication executes at around a quarter of FP32 rate,
# and latency is pretty high too."

proc finalSubMayOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array, carry: ValueRef) =
  ## If a >= Modulus: r <- a-M
  ## else: r <- a
  ##
  ## This is constant-time straightline code.
  ## Due to warp divergence, the overhead of doing comparison with shortcutting might not be worth it on GPU.
  ##
  ## To be used when the final subtraction can
  ## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256)
  ##
  ## `carry` is the carry-out of the preceding limb addition: if set, the
  ## true result exceeds the limbs' capacity and the subtraction of M is mandatory.

  let bld = asy.builder
  let fieldTy = cm.getFieldType(field)
  let wordTy = cm.getWordType(field)
  let scratch = bld.makeArray(fieldTy)
  let M = cm.getModulus(field)
  let N = M.len

  let zero_i1 = constInt(asy.i1_t, 0)
  let zero = constInt(wordTy, 0)

  # Mask: contains 0xFFFF or 0x0000
  # 0 - 0 - carry underflows iff carry is set, broadcasting it to a full word.
  let (_, mask) = bld.subborrow(zero, zero, carry)

  # Now subtract the modulus, and test a < M
  # (underflow) with the last borrow
  var b: ValueRef
  (b, scratch[0]) = bld.subborrow(a[0], M[0], zero_i1)
  for i in 1 ..< N:
    (b, scratch[i]) = bld.subborrow(a[i], M[i], b)

  # If it underflows here, it means that it was
  # smaller than the modulus and we don't need `scratch`.
  # When `carry` was set, mask is all-ones so this cannot underflow:
  # b is forced to 0 and the reduced value `scratch` is selected below.
  (b, _) = bld.subborrow(mask, zero, b)

  for i in 0 ..< N:
    r[i] = bld.select(b, a[i], scratch[i])

proc finalSubNoOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) =
  ## If a >= Modulus: r <- a-M
  ## else: r <- a
  ##
  ## This is constant-time straightline code.
  ## Due to warp divergence, the overhead of doing comparison with shortcutting might not be worth it on GPU.
  ##
  ## To be used when the modulus does not use the full bitwidth of the storing words
  ## (say using 255 bits for the modulus out of 256 available in words)

  let builder = asy.builder
  let M = cm.getModulus(field)
  let candidate = builder.makeArray(cm.getFieldType(field))

  # Emit the borrow chain computing candidate = a - M.
  # The final borrow tells us whether a < M.
  var borrow = constInt(asy.i1_t, 0)
  for i in 0 ..< M.len:
    (borrow, candidate[i]) = builder.subborrow(a[i], M[i], borrow)

  # borrow set => a was already < M => keep a, otherwise take the reduced candidate.
  for i in 0 ..< M.len:
    r[i] = builder.select(borrow, a[i], candidate[i])

proc field_add_gen_sat*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field): FnDef =
  ## Generate an optimized modular addition kernel
  ## with parameters `a, b, modulus: Limbs -> Limbs`
  ##
  ## Returns the (function type, function) pair for the generated kernel.

  let symbol = cm.genSymbol(block:
    case field
    of fp: opFpAdd
    of fr: opFrAdd)
  let fieldTy = cm.getFieldType(field)
  let ptrFieldTy = pointer_t(fieldTy)

  # Kernel signature: void addmod(r, a, b: ptr Limbs)
  let kernelTy = function_t(asy.void_t, [ptrFieldTy, ptrFieldTy, ptrFieldTy])
  let kernel = asy.module.addFunction(cstring symbol, kernelTy)
  let entry = asy.ctx.appendBasicBlock(kernel, "addModSatBody")
  asy.builder.positionAtEnd(entry)

  let builder = asy.builder

  let r = builder.asArray(kernel.getParam(0), fieldTy)
  let a = builder.asArray(kernel.getParam(1), fieldTy)
  let b = builder.asArray(kernel.getParam(2), fieldTy)

  let t = builder.makeArray(fieldTy)

  # Full carry chain: t = a + b (may produce a final carry-out)
  var carry = constInt(asy.i1_t, 0)
  for i in 0 ..< cm.getNumWords(field):
    (carry, t[i]) = builder.addcarry(a[i], b[i], carry)

  # Conditional subtraction of the modulus to bring t back in [0, M)
  if cm.getSpareBits(field) >= 1:
    asy.finalSubNoOverflow(cm, field, t, t)
  else:
    asy.finalSubMayOverflow(cm, field, t, t, carry)

  builder.store(r, t)
  builder.retVoid()

  return (kernelTy, kernel)
6 changes: 6 additions & 0 deletions constantine/math_compiler/ir.nim
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ func getFieldType*(cm: CurveMetadata, field: Field): TypeRef {.inline.} =
else:
return cm.fr.fieldTy

func getWordType*(cm: CurveMetadata, field: Field): TypeRef {.inline.} =
  ## Returns the LLVM type of a single limb (word) for the requested field.
  # Exhaustive case over the field enum, matching the style of getNumWords.
  case field
  of fp: cm.fp.wordTy
  of fr: cm.fr.wordTy

func getNumWords*(cm: CurveMetadata, field: Field): int {.inline.} =
case field
of fp:
Expand Down
32 changes: 20 additions & 12 deletions constantine/platforms/abis/llvm_abi.nim
Original file line number Diff line number Diff line change
Expand Up @@ -601,17 +601,17 @@ type
## An instruction builder represents a point within a basic block and is
## the exclusive means of building instructions using the C interface.

IntPredicate* {.size: sizeof(cint).} = enum
IntEQ = 32 ## equal
IntNE ## not equal
IntUGT ## unsigned greater than
IntUGE ## unsigned greater or equal
IntULT ## unsigned less than
IntULE ## unsigned less or equal
IntSGT ## signed greater than
IntSGE ## signed greater or equal
IntSLT ## signed less than
IntSLE ## signed less or equal
Predicate* {.size: sizeof(cint).} = enum
kEQ = 32 ## equal
kNE ## not equal
kUGT ## unsigned greater than
kUGE ## unsigned greater or equal
kULT ## unsigned less than
kULE ## unsigned less or equal
kSGT ## signed greater than
kSGE ## signed greater or equal
kSLT ## signed less than
kSLE ## signed less or equal

InlineAsmDialect* {.size: sizeof(cint).} = enum
InlineAsmDialectATT
Expand Down Expand Up @@ -675,19 +675,27 @@ proc call2*(

proc add*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildAdd".}
proc addNSW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNSWAdd".}
## Addition No Signed Wrap, i.e. guaranteed to not overflow
proc addNUW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNUWAdd".}
## Addition No Unsigned Wrap, i.e. guaranteed to not overflow

proc sub*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildSub".}
proc subNSW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNSWSub".}
## Substraction No Signed Wrap, i.e. guaranteed to not overflow
proc subNUW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNUWSub".}
## Substraction No Unsigned Wrap, i.e. guaranteed to not overflow

proc neg*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNeg".}
proc negNSW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNSWNeg".}
## Negation No Signed Wrap, i.e. guaranteed to not overflow
proc negNUW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNUWNeg".}
## Negation No Unsigned Wrap, i.e. guaranteed to not overflow

proc mul*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildMul".}
proc mulNSW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNSWMul".}
## Multiplication No Signed Wrap, i.e. guaranteed to not overflow
proc mulNUW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNUWMul".}
## Multiplication No Unsigned Wrap, i.e. guaranteed to not overflow

proc divU*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildUDiv".}
proc divU_exact*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildExactUDiv".}
Expand All @@ -706,7 +714,7 @@ proc `xor`*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueR
proc `not`*(builder: BuilderRef, val: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNot".}
proc select*(builder: BuilderRef, condition, then, otherwise: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildSelect".}

proc icmp*(builder: BuilderRef, op: IntPredicate, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildICmp".}
proc icmp*(builder: BuilderRef, op: Predicate, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildICmp".}

proc bitcast*(builder: BuilderRef, val: ValueRef, destTy: TypeRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildBitcast".}
proc trunc*(builder: BuilderRef, val: ValueRef, destTy: TypeRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildTrunc".}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func muladd2*(hi, lo: var Ct[uint64], a, b, c1, c2: Ct[uint64]) {.inline.}=
{.emit:["*",lo, " = (NU64)", dblPrec,";"].}

func smul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} =
## Extended precision multiplication
## Signed extended precision multiplication
## (hi, lo) <- a*b
##
## Inputs are intentionally unsigned
Expand All @@ -103,4 +103,4 @@ func smul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} =
{.emit:[lo, " = (NU64)", dblPrec,";"].}
else:
{.emit:["*",hi, " = (NU64)(", dblPrec," >> ", 64'u64, ");"].}
{.emit:["*",lo, " = (NU64)", dblPrec,";"].}
{.emit:["*",lo, " = (NU64)", dblPrec,";"].}
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ func smul128(a, b: Ct[uint64], hi: var Ct[uint64]): Ct[uint64] {.importc:"_mul12
## as we use their unchecked raw representation for cryptography

func smul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} =
## Extended precision multiplication
## Signed extended precision multiplication
## (hi, lo) <- a*b
##
## Inputs are intentionally unsigned
## as we use their unchecked raw representation for cryptography
##
##
## This is constant-time on most hardware
## See: https://www.bearssl.org/ctmul.html
lo = smul128(a, b, hi)
lo = smul128(a, b, hi)
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 432a91e

Please sign in to comment.