Skip to content

Commit

Permalink
[cases] add eval.mmm_mem (to be fixed)
Browse files Browse the repository at this point in the history
  • Loading branch information
SharzyL committed Nov 21, 2024
1 parent 8fd3f88 commit 2a4f422
Show file tree
Hide file tree
Showing 4 changed files with 312 additions and 1 deletion.
32 changes: 32 additions & 0 deletions tests/eval/_mmm_mem/default.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Evaluation cases for the `mmm` kernel.
#
# Arguments:
#   linkerScript - linker script passed to the compiler via -T
#   makeBuilder  - factory that returns the case builder for a given prefix
#   t1main       - extra source/object appended to the compile command line
#
# Result: an attribute set with one derivation per (LEN, kernel) combination.
{ linkerScript
, makeBuilder
, t1main
}:

let
  builder = makeBuilder { casePrefix = "eval"; };

  # Build a single mmm case.
  #   caseName   - must be consistent with the attr name it is bound to below
  #   len        - value passed to the C driver as -DLEN
  #   kernel_src - assembly file that provides the `mmm` symbol
  build_mmm = caseName /* must be consistent with attr name */ : len: kernel_src:
    builder {
      inherit caseName;

      src = ./.;

      passthru.featuresRequired = { };

      buildPhase = ''
        runHook preBuild
        $CC -T${linkerScript} -DLEN=${toString len} \
          ${./mmm_main.c} ${kernel_src} \
          ${t1main} \
          -o $pname.elf
        runHook postBuild
      '';

      # Was "test case 'ntt'": a leftover from the NTT case this file was
      # copied from.
      meta.description = "test case 'mmm'";
    };

in {
  mmm_mem_4096_vl4096 = build_mmm "mmm_mem_4096_vl4096" 4096 ./mmm_4096_vl4096.S;
}
254 changes: 254 additions & 0 deletions tests/eval/_mmm_mem/mmm_4096_vl4096.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
# ----------------------------------------------------------------------------
# void mmm(uint32_t *r, const uint32_t *a, const uint32_t *b,
#          const uint32_t *p, const uint32_t mu)
#
# RISC-V vector (RVV) kernel, GNU as syntax.
# NOTE(review): "mmm" presumably stands for Montgomery modular
# multiplication - confirm against the generator of this file.
#
# Arguments as used below (standard integer argument registers):
#   a0 = r / AB : accumulator, read and written in place
#   a1 = a  (A) : multiplicand, read with 3-field segment loads
#   a2 = b  (B) : read one 32-bit word per outer iteration; a2 is advanced
#                 by 4 bytes each time, so the register is modified
#   a3 = p  (P) : read with 3-field segment loads
#   a4 = mu     : scalar multiplied with AB[0] to form the second multiplier
#
# Scratch registers:
#   t0 = scalar multiplier for vmacc, ALSO reused as a temporary for the
#        16-bit mask constant and the inner loop bound
#   t1 = byte stride (12 = 3 words) between consecutive lanes
#   t2 = byte offset / inner ("9:") loop counter
#   t3 = effective address
#   t4 = outer ("1:") loop counter, runs 257 times
#   t5 = group index, always 0 in this unrolled variant
#   v0-v2   = P group,  v10-v12 = A group,  v20-v22 = AB group
#   v30     = carry vector,  v31 = 0xffff mask / slid carry
#
# Memory footprint: vl = 128 lanes, each lane touching 3 consecutive
# 32-bit words at a 12-byte stride, i.e. 128 * 3 = 384 words starting at
# a0, a1 and a3.  B is read at 257 consecutive words.
# Each 32-bit element carries a 16-bit limb; the upper half is used as
# carry headroom and is folded back by the propagate loops.
# ----------------------------------------------------------------------------
.text
.balign 16
.globl mmm
.type mmm,@function
# assume VLEN >= 4096, BN = 4096, SEW = 16 * 2 = 32
# we only support LMUL = 1 for now
# P, A, B, AB should have 384 elements
mmm:
# request vl = 128 lanes of 32-bit elements (AVL goes through a register
# because vsetivli only has a 5-bit immediate)
li t0, 128 # in case way > 31
vsetvli zero, t0, e32, m1, ta, ma
# stride in bytes between lanes: 3 words per lane
li t1, 12
# start outer loop, 257 iterations in total (see the bound at the bottom)
li t4,0
1:
# AB = B_i*A + AB
# !!!!!! important: lw here assumes SEW = 32
# T0 is used in vmacc, do not use for temp now!
lw t0, 0(a2)
addi a2, a2, 4 # advance B by a SEW

# carry for ABV_0
# (v30 is zeroed again at the top of the propagate loop below)
vmv.v.i v30,0
# loop variable
li t5,0

# ---
# macc (V=a1, VV=v10, VVN=10, ngroupreg=3)
# ---

# load one group of values from arg
# offset of one group
# !!! important: assume nreg = 8 and sew = 32
# log(8) + log(32/8) = 5
# NOTE(review): this variant uses groups of 3 registers, not 8; the shift
# is harmless only because t5 is always 0 here, so the offset is 0.
slli t2,t5,5
add t3,t2,a1
vlsseg3e32.v v10, (t3), t1
add t3,t2,a0
vlsseg3e32.v v20, (t3), t1
# v20..v22 += t0 * v10..v12
vmacc.vx v20, t0, v10
vmacc.vx v21, t0, v11
vmacc.vx v22, t0, v12
# store one group of AB
vssseg3e32.v v20, (t3), t1

# ---
# propagate_niter
# ---

# carry propagation loop, 128 iterations (bound loaded before the bne)
# use T2 as outer loop index
li t2,0
9:
# mask
# set TV2 for every propagate()
# set TV2 every time (see slide1up below)
# v31 is overwritten by the vslide1up at the end of each iteration, so the
# 0xffff mask must be rebuilt here; this also clobbers t0.
li t0,65535
vmv.v.x v31,t0

# carry for ABV_0
vmv.v.i v30,0

# loop variable
li t5,0

# load last group of values from arg
# offset of last group
# !!! important: assume nreg = 8 and sew = 32
# log(8) + log(32/8) = 5
# LOOP2 is now ngroup - 1
slli t3,t5,5
add t3,t3,a0
vlsseg3e32.v v20, (t3), t1

# ---
# propagate (j=0, ngroupreg=3)
# ---

vadd.vv v20, v20, v30
# save carry in TV
vsrl.vi v30, v20, 16
# mod 2 ** 16
vand.vv v20, v20, v31
vadd.vv v21, v21, v30

# ---
# propagate (j=1, ngroupreg=3)
# ---

# save carry in TV
vsrl.vi v30, v21, 16
# mod 2 ** 16
vand.vv v21, v21, v31
vadd.vv v22, v22, v30

# ---
# propagate (j=2, ngroupreg=3)
# ---

# save carry in TV
vsrl.vi v30, v22, 16
# mod 2 ** 16
vand.vv v22, v22, v31
# store last group of AB
vssseg3e32.v v20, (t3), t1

# update carry of AB_{ntotalreg - 1} to AB_0
# the carry out of field 2 of lane i belongs to field 0 of lane i+1:
# shift the carry vector up by one lane (lane 0 receives 0) and add it to
# field 0, which is reloaded with a plain strided load
vlse32.v v20, (a0), t1
vslide1up.vx v31, v30, zero
vadd.vv v20, v20, v31
vsse32.v v20, (a0), t1
addi t2,t2,1
li t0,128
bne t2,t0,9b
# second multiplier: t0 = (AB[0] * mu) mod 2^16
# !!!!!! important: lw here assumes SEW = 32
# T0 is used in vmacc, do not use for temp now!
lw t0, 0(a0)
mul t0, t0, a4
# mod 2 ** 16
# !!!! important: here we assume SEW = 32 and XLEN = 64
# NOTE(review): a left/right shift pair by 16 keeps the low 16 bits only
# when XLEN = 32 (XLEN = 64 would keep 48 bits) - confirm which XLEN the
# comment above is meant to describe.
sll t0, t0, 16
srl t0, t0, 16

# loop variable
li t5,0

# ---
# macc (V=a3, VV=v0, VVN=0, ngroupreg=3)
# ---

# load one group of values from arg
# offset of one group
# !!! important: assume nreg = 8 and sew = 32
# log(8) + log(32/8) = 5
slli t2,t5,5
add t3,t2,a3
vlsseg3e32.v v0, (t3), t1
add t3,t2,a0
vlsseg3e32.v v20, (t3), t1
# v20..v22 += t0 * v0..v2
vmacc.vx v20, t0, v0
vmacc.vx v21, t0, v1
vmacc.vx v22, t0, v2
# store one group of AB
vssseg3e32.v v20, (t3), t1

# ---
# propagate_niter
# ---

# second carry propagation loop, same body as the first one;
# the numeric label 9 is redefined, "9b" below binds to this definition
# use T2 as outer loop index
li t2,0
9:
# mask
# set TV2 for every propagate()
# set TV2 every time (see slide1up below)
li t0,65535
vmv.v.x v31,t0

# carry for ABV_0
vmv.v.i v30,0

# loop variable
li t5,0

# load last group of values from arg
# offset of last group
# !!! important: assume nreg = 8 and sew = 32
# log(8) + log(32/8) = 5
# LOOP2 is now ngroup - 1
slli t3,t5,5
add t3,t3,a0
vlsseg3e32.v v20, (t3), t1

# ---
# propagate (j=0, ngroupreg=3)
# ---

vadd.vv v20, v20, v30
# save carry in TV
vsrl.vi v30, v20, 16
# mod 2 ** 16
vand.vv v20, v20, v31
vadd.vv v21, v21, v30

# ---
# propagate (j=1, ngroupreg=3)
# ---

# save carry in TV
vsrl.vi v30, v21, 16
# mod 2 ** 16
vand.vv v21, v21, v31
vadd.vv v22, v22, v30

# ---
# propagate (j=2, ngroupreg=3)
# ---

# save carry in TV
vsrl.vi v30, v22, 16
# mod 2 ** 16
vand.vv v22, v22, v31
# store last group of AB
vssseg3e32.v v20, (t3), t1

# update carry of AB_{ntotalreg - 1} to AB_0
vlse32.v v20, (a0), t1
vslide1up.vx v31, v30, zero
vadd.vv v20, v20, v31
vsse32.v v20, (a0), t1
addi t2,t2,1
li t0,128
bne t2,t0,9b

# update carry of AB_2 to AB_0
# since we need to subtract AB_0
vlse32.v v20, (a0), t1
# AB / word
# shift field 0 down by one lane (last lane receives 0); the result is
# kept in v30 and becomes the new field 2 in the move step below
vslide1down.vx v30, v20, zero
# do not need vsse now
# just store it in TV for move

# -----
# move
# -----

# move AB_1 to AB_0, AB_2 to AB_1, ... , AB_0 (in TV now) to AB_2
# loop variable
li t5,0
# load last group of values from arg
# offset of last group
# !!! important: assume nreg = 8 and sew = 32
# log(8) + log(32/8) = 5
# LOOP2 is now ngroup - 1
slli t2,t5,5
# then offset by 1 element
addi t2,t2,4
add t3,t2,a0
# old fields 1 and 2 are loaded into v20 and v21, i.e. become fields 0 and 1
vlsseg2e32.v v20, (t3), t1
# move AB_0 to AB_2
vmv.v.v v22, v30

# back to original offset
addi t3,t3,-4
vssseg3e32.v v20, (t3), t1

# outer loop bound: 257 iterations
addi t4,t4,1
li t0,257

bne t4,t0,1b

ret
24 changes: 24 additions & 0 deletions tests/eval/_mmm_mem/mmm_main.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>

#ifndef LEN
// define then error, to make clangd happy without configuration
#define LEN 1024
#error "LEN not defined"
#endif

void mmm(uint32_t* r, const uint32_t* a, const uint32_t* b, const uint32_t* p, const uint32_t mu);

// Driver for the mmm kernel: allocates the operands, runs the kernel once
// and prints the first `words` result words in hex.
//
// Buffer sizing: the only kernel shipped with this case
// (mmm_4096_vl4096.S) runs with vl = 128 and uses 3-field segment accesses
// at a 12-byte stride, so it reads/writes 128 * 3 = 384 words of r, a and p
// regardless of LEN.  Allocating only LEN / 16 + 4 words (260 for
// LEN = 4096) overruns the heap, so the buffers are sized to whichever
// requirement is larger.
//
// The buffers are zero-initialised so that the kernel never consumes
// indeterminate heap contents and the printed output is reproducible.
void test() {
  // Number of result words that are printed.
  int words = (LEN) / 16 + 4;

  // Words touched by the kernel per vector operand: 128 lanes * 3 fields.
  const size_t kernel_words = 128 * 3;
  const size_t alloc_words =
      (size_t) words > kernel_words ? (size_t) words : kernel_words;

  uint32_t *r = (uint32_t *) calloc(alloc_words, sizeof(uint32_t));
  uint32_t *a = (uint32_t *) calloc(alloc_words, sizeof(uint32_t));
  uint32_t *b = (uint32_t *) calloc(alloc_words, sizeof(uint32_t));
  uint32_t *p = (uint32_t *) calloc(alloc_words, sizeof(uint32_t));
  uint32_t mu = 0xca1b;

  // Only run the kernel when every operand could be allocated; passing a
  // NULL pointer to it would fault inside the vector loads.
  if (r != NULL && a != NULL && b != NULL && p != NULL) {
    mmm(r, a, b, p, mu);
    for (int i = 0; i < words; i++) {
      // uint32_t is not necessarily `unsigned int`; cast to match %X.
      printf("%04X ", (unsigned) r[i]);
    }
  }

  // free(NULL) is a no-op, so this is safe on the failure path as well.
  free(r);
  free(a);
  free(b);
  free(p);
}
3 changes: 2 additions & 1 deletion tests/eval/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@ let
autoCases = findAndBuild ./. build;

nttCases = callPackage ./_ntt { };
mmmCases = callPackage ./_mmm_mem { };
in
autoCases // nttCases
autoCases // nttCases // mmmCases

0 comments on commit 2a4f422

Please sign in to comment.