-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[cases] add eval.mmm_mem (to be fixed)
- Loading branch information
Showing
4 changed files
with
312 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Case set for the "eval" prefix: each attribute links mmm_main.c, one
# assembly kernel and t1main into a single ELF using the given linker script.
{ linkerScript
, makeBuilder
, t1main
}:

let
  builder = makeBuilder { casePrefix = "eval"; };

  # caseName   - must be consistent with the attribute name used below
  # len        - forwarded to the C harness as -DLEN=<len>
  # kernel_src - assembly file providing the mmm symbol
  build_mmm = caseName: len: kernel_src:
    builder {
      inherit caseName;

      src = ./.;

      passthru.featuresRequired = { };

      buildPhase = ''
        runHook preBuild
        $CC -T${linkerScript} -DLEN=${toString len} \
          ${./mmm_main.c} ${kernel_src} \
          ${t1main} \
          -o $pname.elf
        runHook postBuild
      '';

      # Previously said 'ntt', a leftover from the case this file was copied from.
      meta.description = "test case 'mmm_mem'";
    };

in
{
  mmm_mem_4096_vl4096 = build_mmm "mmm_mem_4096_vl4096" 4096 ./mmm_4096_vl4096.S;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,254 @@ | ||
/*
 * void mmm(uint32_t *r, const uint32_t *a, const uint32_t *b,
 *          const uint32_t *p, const uint32_t mu);
 *
 * RISC-V vector ("V" extension) kernel, GNU as syntax.
 * NOTE(review): judging by the operands (p, mu) and the per-step reduction
 * this is a word-serial Montgomery multiplication; confirm with the generator.
 *
 * Assumptions stated by the generator: VLEN >= 4096, BN = 4096, SEW = 32
 * holding 16-bit limbs, LMUL = 1 only.
 *
 * In:    a0 = r  (AB accumulator; read before it is written, so its initial
 *                 contents take part in the result)
 *        a1 = a, a2 = b, a3 = p, a4 = mu
 * Out:   result limbs left in the buffer at a0
 * Sizes: r, a and p are accessed as 128 segments of 3 words (byte stride 12),
 *        i.e. 384 words each; b is read one word per outer step, 257 words.
 * Clobb: t0-t5, a2 (advanced past the consumed b words), vl/vtype,
 *        v0-v2, v10-v12, v20-v22, v30, v31
 */
    .text
    .balign 16
    .globl mmm
    .type mmm,@function
mmm:
    # AVL of 128 is too large for the 5-bit immediate of vsetivli,
    # so it goes through a register.
    li      t0, 128
    vsetvli zero, t0, e32, m1, ta, ma

    li      t1, 12                      # byte stride between segments (3 words)
    li      t4, 0                       # outer step counter

.Lmmm_outer:
    # ---------------------------------------------------------------
    # step 1: AB += B_i * A
    # ---------------------------------------------------------------
    # lw assumes SEW = 32. t0 feeds vmacc below, keep it live until then.
    lw      t0, 0(a2)                   # t0 = B_i
    addi    a2, a2, 4                   # advance B by one SEW

    vmv.v.i v30, 0                      # carry for ABV_0
    li      t5, 0                       # group index (single group, always 0)

    # byte offset of the group = index << 5
    # (generator assumes nreg = 8 and SEW = 32: log(8) + log(32/8) = 5)
    slli    t2, t5, 5
    add     t3, t2, a1
    vlsseg3e32.v v10, (t3), t1          # v10..v12 = A, three fields per segment
    add     t3, t2, a0
    vlsseg3e32.v v20, (t3), t1          # v20..v22 = AB
    vmacc.vx v20, t0, v10
    vmacc.vx v21, t0, v11
    vmacc.vx v22, t0, v12
    vssseg3e32.v v20, (t3), t1          # store the group of AB back

    # ---------------------------------------------------------------
    # carry propagation over AB, 128 sweeps
    # each sweep moves carries across one segment boundary
    # ---------------------------------------------------------------
    li      t2, 0                       # sweep counter
.Lmmm_carry_ab:
    # v31 doubles as slide1up destination below, so the limb mask
    # has to be rebuilt on every sweep.
    li      t0, 65535
    vmv.v.x v31, t0                     # v31 = 0xFFFF in every lane

    vmv.v.i v30, 0                      # carry for ABV_0
    li      t5, 0                       # group index

    slli    t3, t5, 5
    add     t3, t3, a0
    vlsseg3e32.v v20, (t3), t1          # v20..v22 = AB

    # field 0
    vadd.vv v20, v20, v30
    vsrl.vi v30, v20, 16                # carry out of field 0
    vand.vv v20, v20, v31               # mod 2 ** 16
    vadd.vv v21, v21, v30

    # field 1
    vsrl.vi v30, v21, 16                # carry out of field 1
    vand.vv v21, v21, v31               # mod 2 ** 16
    vadd.vv v22, v22, v30

    # field 2
    vsrl.vi v30, v22, 16                # carry out of field 2
    vand.vv v22, v22, v31               # mod 2 ** 16
    vssseg3e32.v v20, (t3), t1          # store the group of AB back

    # carry out of the last field goes to field 0 of the next segment
    vlse32.v v20, (a0), t1              # field 0 of every segment
    vslide1up.vx v31, v30, zero         # lane k receives the carry of lane k-1
    vadd.vv v20, v20, v31
    vsse32.v v20, (a0), t1

    addi    t2, t2, 1
    li      t0, 128
    bne     t2, t0, .Lmmm_carry_ab

    # ---------------------------------------------------------------
    # step 2: q = (AB_0 * mu) mod 2 ** 16, then AB += q * P
    # ---------------------------------------------------------------
    # lw assumes SEW = 32. t0 feeds vmacc below, keep it live until then.
    lw      t0, 0(a0)
    mul     t0, t0, a4
    # Reduce with an explicit mask. The former shift pair (left 16, right 16)
    # only gave mod 2 ** 16 when XLEN = 32; the mask gives the same result
    # there and is also correct when XLEN = 64.
    li      t3, 65535
    and     t0, t0, t3                  # t0 = q

    li      t5, 0                       # group index

    slli    t2, t5, 5
    add     t3, t2, a3
    vlsseg3e32.v v0, (t3), t1           # v0..v2 = P
    add     t3, t2, a0
    vlsseg3e32.v v20, (t3), t1          # v20..v22 = AB
    vmacc.vx v20, t0, v0
    vmacc.vx v21, t0, v1
    vmacc.vx v22, t0, v2
    vssseg3e32.v v20, (t3), t1          # store the group of AB back

    # ---------------------------------------------------------------
    # carry propagation over AB, 128 sweeps (same sequence as after step 1)
    # ---------------------------------------------------------------
    li      t2, 0                       # sweep counter
.Lmmm_carry_qp:
    # rebuild the limb mask, v31 is overwritten by slide1up below
    li      t0, 65535
    vmv.v.x v31, t0                     # v31 = 0xFFFF in every lane

    vmv.v.i v30, 0                      # carry for ABV_0
    li      t5, 0                       # group index

    slli    t3, t5, 5
    add     t3, t3, a0
    vlsseg3e32.v v20, (t3), t1          # v20..v22 = AB

    # field 0
    vadd.vv v20, v20, v30
    vsrl.vi v30, v20, 16                # carry out of field 0
    vand.vv v20, v20, v31               # mod 2 ** 16
    vadd.vv v21, v21, v30

    # field 1
    vsrl.vi v30, v21, 16                # carry out of field 1
    vand.vv v21, v21, v31               # mod 2 ** 16
    vadd.vv v22, v22, v30

    # field 2
    vsrl.vi v30, v22, 16                # carry out of field 2
    vand.vv v22, v22, v31               # mod 2 ** 16
    vssseg3e32.v v20, (t3), t1          # store the group of AB back

    # carry out of the last field goes to field 0 of the next segment
    vlse32.v v20, (a0), t1              # field 0 of every segment
    vslide1up.vx v31, v30, zero         # lane k receives the carry of lane k-1
    vadd.vv v20, v20, v31
    vsse32.v v20, (a0), t1

    addi    t2, t2, 1
    li      t0, 128
    bne     t2, t0, .Lmmm_carry_qp

    # ---------------------------------------------------------------
    # step 3: AB = AB / word (drop the lowest limb)
    # ---------------------------------------------------------------
    # Field 0 of segment k+1 becomes field 2 of segment k. It is kept in
    # v30 and merged by the move below, so no store is needed here.
    vlse32.v v20, (a0), t1              # field 0 of every segment
    vslide1down.vx v30, v20, zero       # lane k = field 0 of segment k+1, last lane = 0

    # move: field 1 -> field 0, field 2 -> field 1, shifted field 0 -> field 2
    li      t5, 0                       # group index
    slli    t2, t5, 5
    addi    t2, t2, 4                   # skip one element to start at field 1
    add     t3, t2, a0
    vlsseg2e32.v v20, (t3), t1          # v20 = field 1, v21 = field 2
    vmv.v.v v22, v30                    # v22 = shifted field 0

    addi    t3, t3, -4                  # back to the start of the group
    vssseg3e32.v v20, (t3), t1

    addi    t4, t4, 1
    li      t0, 257
    bne     t4, t0, .Lmmm_outer

    ret
    .size mmm, .-mmm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#include <stdint.h> | ||
#include <stdlib.h> | ||
#include <stdio.h> | ||
|
||
#ifndef LEN | ||
// define then error, to make clangd happy without configuration | ||
#define LEN 1024 | ||
#error "LEN not defined" | ||
#endif | ||
|
||
void mmm(uint32_t* r, const uint32_t* a, const uint32_t* b, const uint32_t* p, const uint32_t mu); | ||
|
||
void test() { | ||
int words = (LEN) / 16 + 4; | ||
uint32_t *r = (uint32_t *) malloc(words * sizeof(uint32_t)); | ||
uint32_t *a = (uint32_t *) malloc(words * sizeof(uint32_t)); | ||
uint32_t *b = (uint32_t *) malloc(words * sizeof(uint32_t)); | ||
uint32_t *p = (uint32_t *) malloc(words * sizeof(uint32_t)); | ||
uint32_t mu = 0xca1b; | ||
mmm(r, a, b, p, mu); | ||
for (int i = 0; i < words; i++) { | ||
printf("%04X ", r[i]); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters