Skip to content

Commit

Permalink
refine some more
Browse files Browse the repository at this point in the history
  • Loading branch information
oxarbitrage committed Nov 24, 2023
1 parent 0b97dc7 commit fd7a736
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 49 deletions.
2 changes: 1 addition & 1 deletion src/Crypt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ The resulting sequence is used in the Salsa20 encryption process.
-}
{-@ ignore iOver64Compute @-}
iOver64Compute :: Integral a => a -> [Word32]
iOver64Compute index = extractBytes 8 $ floor (fromIntegral index / 64 :: Double)
iOver64Compute index = extractBytes8 $ floor (fromIntegral index / 64 :: Double)

-- |Display the calculation of an index over the scalar 64. See `iOver64Compute`.
{-@ ignore iOver64Display @-}
Expand Down
6 changes: 3 additions & 3 deletions src/Doubleround.hs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ doubleroundRDisplay r input

-- |The doubleroundR Keelung expression.
{-@ ignore doubleroundRKeelung @-}
doubleroundRKeelung :: [UInt 32] -> Int -> Comp [UInt 32]
doubleroundRKeelung input r
doubleroundRKeelung :: Int -> [UInt 32] -> Comp [UInt 32]
doubleroundRKeelung r input
| length input == 16 && r == 0 = return input
| length input == 16 && r > 0 = doubleroundKeelung input >>= \dr -> doubleroundRKeelung dr (r - 1)
| length input == 16 && r > 0 = doubleroundKeelung input >>= \dr -> doubleroundRKeelung (r - 1) dr
| otherwise = error "arguments of `doubleroundRKeelung` must be a list of 16 `UInt 32` numbers and a number `Int` of rounds"
8 changes: 4 additions & 4 deletions src/Expansion.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ t3 = [116, 101, 32, 107]
{-@ ignore expand32Compute @-}
expand32Compute :: [Word32] -> [Word32] -> [Word32] -> [Word32]
expand32Compute k0 k1 n
| length k0 == 16 && length k1 == 16 && length n == 16 = salsa20Compute (sort32Compute k0 k1 n) 10
| length k0 == 16 && length k1 == 16 && length n == 16 = salsa20Compute 10 (sort32Compute k0 k1 n)
| otherwise = error "inputs to `expand32Compute` must be 2 lists of 16 `Word32` numbers as k0 and k1; and a \
\list of 16 `Word32` numbers as an `n`"

-- |The expansion function displayed where we have two 16 bytes (k0 and k1).
{-@ ignore expand32Display @-}
expand32Display :: [String] -> [String] -> [String] -> [String]
expand32Display k0 k1 n
| length k0 == 16 && length k1 == 16 && length n == 16 = salsa20Display (sort32Display k0 k1 n) 10
| length k0 == 16 && length k1 == 16 && length n == 16 = salsa20Display 10 (sort32Display k0 k1 n)
| otherwise = error "inputs to `expand32Display` must be 2 lists of 16 `String` strings as a k0 and k1; and a \
\list of 16 `String` strings as an `n`"

Expand All @@ -98,15 +98,15 @@ sort32Display k0 k1 n = numberListToStringList o0 ++ k0 ++ numberListToStringLis
{-@ ignore expand16Compute @-}
expand16Compute :: [Word32] -> [Word32] -> [Word32]
expand16Compute k n
| length k == 16 && length n == 16 = salsa20Compute (sort16Compute k n) 10
| length k == 16 && length n == 16 = salsa20Compute 10 (sort16Compute k n)
| otherwise = error "inputs to `expand16Compute` must be a list of 16 `Word32` numbers as a key and a \
\list of 16 `Word32` numbers as an `n`"

-- |The expansion function displayed where we have one 16 bytes (k).
{-@ ignore expand16Display @-}
expand16Display :: [String] -> [String] -> [String]
expand16Display k n
| length k == 16 && length n == 16 = salsa20Display (sort16Display k n) 10
| length k == 16 && length n == 16 = salsa20Display 10 (sort16Display k n)
| otherwise = error "inputs to `expand16Display` must be a list of 16 `String` strings as a key and a \
\list of 16 `String` strings as an `n`"

Expand Down
36 changes: 18 additions & 18 deletions src/Hash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,40 +45,40 @@ import Doubleround
import Utils

-- |The core expression computed.
{-@ ignore coreCompute @-}
coreCompute :: [Word32] -> Int -> [Word32]
coreCompute input rounds
{-@ coreCompute :: Nat -> { i:[_] | (len i) == 16 } -> { o:[_] | (len o) == 16 } @-}
coreCompute :: Int -> [Word32] -> [Word32]
coreCompute rounds input
| length input == 16 = modMatrix (doubleroundRCompute rounds input) input
| otherwise = error "input to `coreCompute` must be a list of 16 `Word32` numbers"

-- |The core expression as a string.
{-@ ignore coreDisplay @-}
coreDisplay :: [String] -> Int -> [String]
coreDisplay input rounds
{-@ coreDisplay :: Nat -> { i:[_] | (len i) == 16 } -> { o:[_] | (len o) == 16 } @-}
coreDisplay :: Int -> [String] -> [String]
coreDisplay rounds input
| length input == 16 = modMatrixDisplay (doubleroundRDisplay rounds input) input
| otherwise = error "input to `coreDisplay` must be a list of 16 `String` strings"

-- |The core Keelung expression computed.
{-@ ignore coreKeelung @-}
coreKeelung :: [UInt 32] -> Int -> Comp [UInt 32]
coreKeelung input rounds
coreKeelung :: Int -> [UInt 32] -> Comp [UInt 32]
coreKeelung rounds input
| length input == 16 = do
dr <- doubleroundRKeelung input rounds
dr <- doubleroundRKeelung rounds input
return $ modMatrixKeelung dr input
| otherwise = error "input to `coreCompute` must be a list of 16 `Word32` numbers"

-- | The salsa20 expression computed.
{-@ ignore salsa20Compute @-}
salsa20Compute :: [Word32] -> Int -> [Word32]
salsa20Compute input rounds
| length input == 64 = aument $ coreCompute (Utils.reduce input) rounds
{-@ salsa20Compute :: Nat -> { i:[_] | (len i) == 64 } -> { o:[_] | (len o) == 64 } @-}
salsa20Compute :: Int -> [Word32] -> [Word32]
salsa20Compute rounds input
| length input == 64 && rounds >= 0 = aument $ coreCompute rounds (Utils.reduce input)
| otherwise = error "input to `salsa20Compute` must be a list of 64 `Word32` numbers"

-- |The salsa20 expression as a string using `coreDisplay`. Call with r = 1, which is one round of doubleround.
{-@ ignore salsa20Display @-}
salsa20Display :: [String] -> Int -> [String]
salsa20Display input rounds
| length input == 64 = aumentDisplay $ coreDisplay (reduceDisplay input) rounds
salsa20Display :: Int ->[String] -> [String]
salsa20Display rounds input
| length input == 64 = aumentDisplay $ coreDisplay rounds (reduceDisplay input)
| otherwise = error "input to `salsa20Display` must be a list of 64 `String` strings"

-- | The salsa20 Keelung expression computed.
Expand All @@ -87,12 +87,12 @@ salsa20Keelung :: [UInt 32] -> Comp [UInt 32]
salsa20Keelung input
| length input == 64 = do
let new_input = reduceKeelung input
core <- coreKeelung new_input 10
core <- coreKeelung 10 new_input
return $ aumentKeelung core
| otherwise = error "input to `salsa20Compute` must be a list of 64 `Word32` numbers"

-- |Execute `salsa20` a specified number of times, this is not part of the protocol and just used in a test case.
{-@ ignore salsa20powerCompute @-}
salsa20powerCompute :: [Word32] -> Int -> [Word32]
salsa20powerCompute input 0 = input
salsa20powerCompute input power = salsa20powerCompute (salsa20Compute input 10) (power - 1)
salsa20powerCompute input power = salsa20powerCompute (salsa20Compute 10 input) (power - 1)
81 changes: 69 additions & 12 deletions src/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ This module provides general utility functions used in the creation of the Salsa

module Utils
(
littleendian, extractBytes, displayBytes, Utils.reduce, reduceDisplay, reduceKeelung,
littleendian, extractBytes4, extractBytes8, displayBytes, Utils.reduce, reduceDisplay, reduceKeelung,
aument, aumentDisplay, aumentKeelung,
modMatrix, numberListToStringList, transpose, modMatrixDisplay, modMatrixKeelung,
eitherListToNumberList, eitherListToStringList,
Expand All @@ -32,6 +32,7 @@ import Keelung hiding (input, eq)
import Operators()

-- | Encode a vector as a word using little-endian byte order.
{-@ littleendian :: { i:[_] | (len i) == 4 } -> _ @-}
littleendian :: [Word32] -> Word32
littleendian bytes = sum [byte `Data.Bits.shiftL` (8 * i) | (i, byte) <- zip [0..] bytes]

Expand All @@ -45,9 +46,24 @@ littleendianKeelung :: [UInt 32] -> UInt 32
littleendianKeelung bytes = sum [byte `Keelung.shiftL` (8 * i) | (i, byte) <- zip [0..] bytes]

-- | Extract a specified number of bytes from a Word32.
extractBytes :: Int -> Word32 -> [Word32]
extractBytes numBytes w =
[ Data.Bits.shiftR w (8 * i) Data.Bits..&. 0xff | i <- [0 .. numBytes - 1] ]
--{-@ ignore extractBytes @-}
{-@ extractBytes4 :: _ -> { o:[_] | (len o) == 4 } @-}
extractBytes4 :: Word32 -> [Word32]
extractBytes4 w = [ Data.Bits.shiftR w (8 * 0) Data.Bits..&. 0xff,
Data.Bits.shiftR w (8 * 1) Data.Bits..&. 0xff,
Data.Bits.shiftR w (8 * 2) Data.Bits..&. 0xff,
Data.Bits.shiftR w (8 * 3) Data.Bits..&. 0xff]

{-@ extractBytes8 :: _ -> { o:[_] | (len o) == 8 } @-}
extractBytes8 :: Word32 -> [Word32]
extractBytes8 w = [ Data.Bits.shiftR w (8 * 0) Data.Bits..&. 0xff,
Data.Bits.shiftR w (8 * 1) Data.Bits..&. 0xff,
Data.Bits.shiftR w (8 * 2) Data.Bits..&. 0xff,
Data.Bits.shiftR w (8 * 3) Data.Bits..&. 0xff,
Data.Bits.shiftR w (8 * 4) Data.Bits..&. 0xff,
Data.Bits.shiftR w (8 * 5) Data.Bits..&. 0xff,
Data.Bits.shiftR w (8 * 6) Data.Bits..&. 0xff,
Data.Bits.shiftR w (8 * 7) Data.Bits..&. 0xff]

-- | Display a specified number of bytes from a string as a list of strings.
displayBytes :: Int -> String -> [String]
Expand All @@ -61,8 +77,25 @@ extractBytesKeelung numBytes w =
[ Keelung.shiftR w (8 * i) Keelung..&. 0xff | i <- [0 .. numBytes - 1] ]

-- |Reduce a matrix of 64 elements to a matrix of 16 elements by using `littleendian` encoding.
{-@ reduce :: { i:[_] | (len i) == 64 } -> { o:[_] | (len o) == 16 } @-}
reduce :: [Word32] -> [Word32]
reduce input = map littleendian $ chunksOf 4 input
reduce input = map littleendian [
[input!!0, input!!1, input!!2, input!!3],
[input!!4, input!!5, input!!6, input!!7],
[input!!8, input!!9, input!!10, input!!11],
[input!!12, input!!13, input!!14, input!!15],
[input!!16, input!!17, input!!18, input!!19],
[input!!20, input!!21, input!!22, input!!23],
[input!!24, input!!25, input!!26, input!!27],
[input!!28, input!!29, input!!30, input!!31],
[input!!32, input!!33, input!!34, input!!35],
[input!!36, input!!37, input!!38, input!!39],
[input!!40, input!!41, input!!42, input!!43],
[input!!44, input!!45, input!!46, input!!47],
[input!!48, input!!49, input!!50, input!!51],
[input!!52, input!!53, input!!54, input!!55],
[input!!56, input!!57, input!!58, input!!59],
[input!!60, input!!61, input!!62, input!!63]]

-- |Reduce a matrix of 64 elements to a matrix of 16 elements by using `littleendianDisplay` encoding.
reduceDisplay :: [String] -> [String]
Expand All @@ -73,8 +106,25 @@ reduceKeelung :: [UInt 32] -> [UInt 32]
reduceKeelung input = map littleendianKeelung $ chunksOf 4 input

-- |Aument a matrix of 16 elements to one of 64 elements by using `extractBytes`.
{-@ aument :: { i:[Word32] | (len i) == 16 } -> { o:[Word32] | (len o) == 64 } @-}
aument :: [Word32] -> [Word32]
aument = concatMap $ extractBytes 4
aument input = [
extractBytes4 (input!!0)!!0, extractBytes4 (input!!0)!!1, extractBytes4 (input!!0)!!2, extractBytes4 (input!!0)!!3,
extractBytes4 (input!!1)!!0, extractBytes4 (input!!1)!!1, extractBytes4 (input!!1)!!2, extractBytes4 (input!!1)!!3,
extractBytes4 (input!!2)!!0, extractBytes4 (input!!2)!!1, extractBytes4 (input!!2)!!2, extractBytes4 (input!!2)!!3,
extractBytes4 (input!!3)!!0, extractBytes4 (input!!3)!!1, extractBytes4 (input!!3)!!2, extractBytes4 (input!!3)!!3,
extractBytes4 (input!!4)!!0, extractBytes4 (input!!4)!!1, extractBytes4 (input!!4)!!2, extractBytes4 (input!!4)!!3,
extractBytes4 (input!!5)!!0, extractBytes4 (input!!5)!!1, extractBytes4 (input!!5)!!2, extractBytes4 (input!!5)!!3,
extractBytes4 (input!!6)!!0, extractBytes4 (input!!6)!!1, extractBytes4 (input!!6)!!2, extractBytes4 (input!!6)!!3,
extractBytes4 (input!!7)!!0, extractBytes4 (input!!7)!!1, extractBytes4 (input!!7)!!2, extractBytes4 (input!!7)!!3,
extractBytes4 (input!!8)!!0, extractBytes4 (input!!8)!!1, extractBytes4 (input!!8)!!2, extractBytes4 (input!!8)!!3,
extractBytes4 (input!!9)!!0, extractBytes4 (input!!9)!!1, extractBytes4 (input!!9)!!2, extractBytes4 (input!!9)!!3,
extractBytes4 (input!!10)!!0, extractBytes4 (input!!10)!!1, extractBytes4 (input!!10)!!2, extractBytes4 (input!!10)!!3,
extractBytes4 (input!!11)!!0, extractBytes4 (input!!11)!!1, extractBytes4 (input!!11)!!2, extractBytes4 (input!!11)!!3,
extractBytes4 (input!!12)!!0, extractBytes4 (input!!12)!!1, extractBytes4 (input!!12)!!2, extractBytes4 (input!!12)!!3,
extractBytes4 (input!!13)!!0, extractBytes4 (input!!13)!!1, extractBytes4 (input!!13)!!2, extractBytes4 (input!!13)!!3,
extractBytes4 (input!!14)!!0, extractBytes4 (input!!14)!!1, extractBytes4 (input!!14)!!2, extractBytes4 (input!!14)!!3,
extractBytes4 (input!!15)!!0, extractBytes4 (input!!15)!!1, extractBytes4 (input!!15)!!2, extractBytes4 (input!!15)!!3]

-- |Aument a matrix of 16 elements to one of 64 elements by using `displayBytes`.
aumentDisplay :: [String] -> [String]
Expand All @@ -85,21 +135,28 @@ aumentKeelung :: [UInt 32] -> [UInt 32]
aumentKeelung = concatMap $ extractBytesKeelung 4

-- |Given two matrices, do modulo addition on each of the elements.
{-@ modMatrix :: { i:[_] | (len i) == 16 } -> { i:[_] | (len i) == 16 } -> { o:[_] | (len o) == 16 } @-}
modMatrix :: [Word32] -> [Word32] -> [Word32]
modMatrix = zipWith (+)
modMatrix a b = [
a!!0 + b!!0, a!!1 + b!!1, a!!2 + b!!2, a!!3 + b!!3,
a!!4 + b!!4, a!!5 + b!!5, a!!6 + b!!6, a!!7 + b!!7,
a!!8 + b!!8, a!!9 + b!!9, a!!10 + b!!10, a!!11 + b!!11,
a!!12 + b!!12, a!!13 + b!!13, a!!14 + b!!14, a!!15 + b!!15]

-- |Given two matrices, display the modulo addition on each of the elements.
{-@ modMatrixDisplay :: { i:[_] | (len i) == 16 } -> { i:[_] | (len i) == 16 } -> { o:[_] | (len o) == 16 } @-}
modMatrixDisplay :: [String] -> [String] -> [String]
modMatrixDisplay = zipWith displayMod
modMatrixDisplay a b =
[
a!!0 ++ " + " ++ b!!0, a!!1 ++ " + " ++ b!!1, a!!2 ++ " + " ++ b!!2, a!!3 ++ " + " ++ b!!3,
a!!4 ++ " + " ++ b!!4, a!!5 ++ " + " ++ b!!5, a!!6 ++ " + " ++ b!!6, a!!7 ++ " + " ++ b!!7,
a!!8 ++ " + " ++ b!!8, a!!9 ++ " + " ++ b!!9, a!!10 ++ " + " ++ b!!10, a!!11 ++ " + " ++ b!!11,
a!!12 ++ " + " ++ b!!12, a!!13 ++ " + " ++ b!!13, a!!14 ++ " + " ++ b!!14, a!!15 ++ " + " ++ b!!15]

-- |Given two matrices, do modulo addition on each of the elements using Keelung types.
modMatrixKeelung :: [UInt 32] -> [UInt 32] -> [UInt 32]
modMatrixKeelung = zipWith (+)

-- Append modulo addition symbol and a string to a given string.
displayMod :: String -> String -> String
displayMod a b = a ++ " + " ++ b

-- |Transpose a 4x4 matrix type.
{-@ transpose :: {i:[_] | len i == 16} -> {o:[_] | len o == 16} @-}
transpose :: [a] -> [a]
Expand Down
12 changes: 6 additions & 6 deletions test/unit/fast/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -414,15 +414,15 @@ main = do
putStrLn ""

putStrLn "Littleendian inverse tests:"
putStrLn $ if extractBytes 4 littleendianOutput1 == littleendianInput1 then "OK" else "FAIL!"
putStrLn $ if extractBytes 4 littleendianOutput2 == littleendianInput2 then "OK" else "FAIL!"
putStrLn $ if extractBytes 4 littleendianOutput3 == littleendianInput3 then "OK" else "FAIL!"
putStrLn $ if extractBytes4 littleendianOutput1 == littleendianInput1 then "OK" else "FAIL!"
putStrLn $ if extractBytes4 littleendianOutput2 == littleendianInput2 then "OK" else "FAIL!"
putStrLn $ if extractBytes4 littleendianOutput3 == littleendianInput3 then "OK" else "FAIL!"
putStrLn ""

putStrLn "Salsa20 tests:"
putStrLn $ if salsa20Compute salsa20Input1 10 == salsa20Output1 then "OK" else "FAIL!"
putStrLn $ if salsa20Compute salsa20Input2 10 == salsa20Output2 then "OK" else "FAIL!"
putStrLn $ if salsa20Compute salsa20Input3 10 == salsa20Output3 then "OK" else "FAIL!"
putStrLn $ if salsa20Compute 10 salsa20Input1 == salsa20Output1 then "OK" else "FAIL!"
putStrLn $ if salsa20Compute 10 salsa20Input2 == salsa20Output2 then "OK" else "FAIL!"
putStrLn $ if salsa20Compute 10 salsa20Input3 == salsa20Output3 then "OK" else "FAIL!"
putStrLn ""

putStrLn "Expanded Salsa20 tests:"
Expand Down
10 changes: 5 additions & 5 deletions test/unit/keelung/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ main = do
-- https://github.com/btq-ag/keelung-compiler/issues/35

let doubleroundR_computed = doubleroundRCompute 10 demoInputWord32
doubleroundR_interpreted <- interpret gf181 (doubleroundRKeelung demoInputUInt32 10) [] []
doubleroundR_interpreted <- interpret gf181 (doubleroundRKeelung 10 demoInputUInt32) [] []
putStrLn $ if doubleroundR_computed == map fromIntegral doubleroundR_interpreted then "OK" else "FAIL!"

{-
Expand All @@ -102,7 +102,7 @@ main = do
-}

let doubleround10_computed = doubleroundRCompute 10 demoInputWord32
doubleround10_interpreted <- interpret gf181 (doubleroundRKeelung demoInputUInt32 10) [] []
doubleround10_interpreted <- interpret gf181 (doubleroundRKeelung 10 demoInputUInt32) [] []
putStrLn $ if doubleround10_computed == map fromIntegral doubleround10_interpreted then "OK" else "FAIL!"

{-
Expand All @@ -111,8 +111,8 @@ main = do
putStrLn $ if douleround10_constraints == 181940 then "OK" else "FAIL!"
-}

let core_computed = coreCompute demoInputWord32 10
core_interpreted <- interpret gf181 (coreKeelung demoInputUInt32 10) [] []
let core_computed = coreCompute 10 demoInputWord32
core_interpreted <- interpret gf181 (coreKeelung 10 demoInputUInt32) [] []
putStrLn $ if core_computed == map fromIntegral core_interpreted then "OK" else "FAIL!"

{-
Expand All @@ -121,7 +121,7 @@ main = do
putStrLn $ if core_constraints == 181956 then "OK" else "FAIL!"
-}

let salsa20_computed = salsa20Compute demoSalsa20InputWord32 10
let salsa20_computed = salsa20Compute 10 demoSalsa20InputWord32
salsa20_interpreted <- interpret gf181 (salsa20Keelung demoSalsa20InputUInt32) [] []
putStrLn $ if salsa20_computed == map fromIntegral salsa20_interpreted then "OK" else "FAIL!"

Expand Down

0 comments on commit fd7a736

Please sign in to comment.