Skip to content

Commit

Permalink
update arith with div by zero and div overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
zkronos73 committed Nov 20, 2024
1 parent dcf23ee commit 7a00cb7
Show file tree
Hide file tree
Showing 8 changed files with 527 additions and 677 deletions.
2 changes: 1 addition & 1 deletion pil/src/pil_helpers/traces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ trace!(RomRow, RomTrace<F> {
});

trace!(ArithRow, ArithTrace<F> {
carry: [F; 7], a: [F; 4], b: [F; 4], c: [F; 4], d: [F; 4], na: F, nb: F, nr: F, np: F, sext: F, m32: F, div: F, fab: F, na_fb: F, nb_fa: F, debug_main_step: F, main_div: F, main_mul: F, signed: F, op: F, bus_res1: F, multiplicity: F, range_ab: F, range_cd: F,
carry: [F; 7], a: [F; 4], b: [F; 4], c: [F; 4], d: [F; 4], na: F, nb: F, nr: F, np: F, sext: F, m32: F, div: F, fab: F, na_fb: F, nb_fa: F, debug_main_step: F, main_div: F, main_mul: F, signed: F, div_by_zero: F, div_overflow: F, inv_sum_all_bs: F, op: F, bus_res1: F, multiplicity: F, range_ab: F, range_cd: F,
});

trace!(ArithTableRow, ArithTableTrace<F> {
Expand Down
83 changes: 64 additions & 19 deletions state-machines/arith/pil/arith.pil
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ require "arith_range_table.pil"
const int OP_LT_ABS = 0x9F;

airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_result = 0) {
// TODO: const int enable_div = 1, const int enable_32_bits = 1, const int enable_64_bits = 1

// NOTE:
// Divisions and remainders by 0 are done by QuickOps

const int CHUNK_SIZE = 2**16;
const int CHUNKS_INPUT = 4;
Expand Down Expand Up @@ -49,16 +45,64 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu
col witness main_div;
col witness main_mul;
col witness signed;

col witness div_by_zero;
col witness div_overflow;

main_div * (main_div - 1) === 0;
main_mul * (main_mul - 1) === 0;
main_mul * main_div === 0;
signed * (1 - signed) === 0;
div_by_zero * (1 - div_by_zero) === 0;
div_overflow * (1 - div_overflow) === 0;

// factor ab € {-1, 1}
fab === 1 - 2 * na - 2 * nb + 4 * na * nb;
na_fb === na * (1 - 2 * nb);
nb_fa === nb * (1 - 2 * na);

expr sum_all_bs = 0;
for (int i = 0; i < length(b); ++i) {
div_by_zero * b[i] === 0; // forces b must be zero when div_by_zero
sum_all_bs = sum_all_bs + b[i]; // all b are values of 16 bits (verified by range_check)
}

// when div_by_zero, a it's free, with this force a must be 0xFFFF
div_by_zero * (a[0] - 0xFFFF) === 0;
div_by_zero * (a[1] - 0xFFFF) === 0;
div_by_zero * (a[2] - (1 - m32) * 0xFFFF) === 0;
div_by_zero * (a[3] - (1 - m32) * 0xFFFF) === 0;

// when div_by_zero, a it's free, with this force a must be 0xFFFF
div_overflow * (b[0] - 0xFFFF) === 0;
div_overflow * (b[1] - 0xFFFF) === 0;
div_overflow * (b[2] - (1 - m32) * 0xFFFF) === 0;
div_overflow * (b[3] - (1 - m32) * 0xFFFF) === 0;

// when div_by_zero, a it's free, with this force a must be 0xFFFF
div_overflow * c[0] === 0;
div_overflow * (c[1] - m32 * 0x8000) === 0;
div_overflow * c[2] === 0;
div_overflow * (c[3] - (1 - m32) * 0x8000) === 0;

// b != 0 <==> sum_all_bs != 0
col witness inv_sum_all_bs;

// div = 0 => div_by_zero must be 0 => 0 (no need calculate inverse)
// div = 1 and div_by_zero = 0 => 1 calculate inverse to demostrate b != 0
// div = 1 and div_by_zero = 1 => 0 (no need calculate inverse)
(div - div_by_zero) * (1 - inv_sum_all_bs * sum_all_bs) === 0;

// div_by_zero only active for divisions
div_by_zero * (1 - div) === 0;

// div_overflow only active for signed divisions
div_overflow * (1 - div) === 0;
div_overflow * (1 - signed) === 0;

div_overflow * div_by_zero === 0;
div_by_zero * div_overflow === 0;

const expr eq[CHUNKS_OP];

// NOTE: Equations with m32 for multiplication not exists, because mul m32 it's an unsigned operation.
Expand Down Expand Up @@ -177,19 +221,19 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu

col witness op;

// div m32 sa sb primary secondary opcodes na nb np nr sext(c)
// -----------------------------------------------------------------------------
// 0 0 0 0 mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0
// 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 d3 =0 =0 =0 a3, d3
// 0 0 1 1 mul mulh (0xb4,0xb5) a3 b3 d3 =0 =0 =0 a3,b3, d3
// 0 1 0 0 mul_w *n/a* (0xb6,0xb7) =0 =0 =0 =0 c1 =0
// div m32 sa sb primary secondary opcodes na nb np nr sext(c)
// -------------------------------------------------------------------------------------
// 0 0 0 0 mulu muluh 0xb0 176 0xb1 177 =0 =0 =0 =0 =0 =0
// 0 0 1 0 *n/a* mulsuh 0xb2 - 0xb3 179 a3 =0 d3 =0 =0 =0 a3, d3
// 0 0 1 1 mul mulh 0xb4 180 0xb5 181 a3 b3 d3 =0 =0 =0 a3,b3, d3
// 0 1 0 0 mul_w *n/a* 0xb6 182 0xb7 - =0 =0 =0 =0 c1 =0

// div m32 sa sb primary secondary opcodes na nb np nr sext(a,d)(*2)
// ------------------------------------------------------------------------------
// 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0
// 1 0 1 1 div rem (0xba,0xbb) a3 b3 c3 d3 =0 =0 a3,b3,c3,d3
// 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 a1 d1 a1 ,d1
// 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 c1 d1 a1 d1 a1,b1,c1,d1
// div m32 sa sb primary secondary opcodes na nb np nr sext(a,d)(*2)
// ------------------------------------------------------------------------------------------
// 1 0 0 0 divu remu 0xb8 184 0xb9 185 =0 =0 =0 =0 =0 =0
// 1 0 1 1 div rem 0xba 186 0xbb 187 a3 b3 c3 d3 =0 =0 a3,b3,c3,d3
// 1 1 0 0 divu_w remu_w 0xbc 188 0xbd 189 =0 =0 =0 =0 a1 d1 a1 ,d1
// 1 1 1 1 div_w rem_w 0xbe 190 0xbf 191 a1 b1 c1 d1 a1 d1 a1,b1,c1,d1

// (*) removed combinations of flags div,m32,sa,sb did allow combinations div, m32, sa, sb
// (*2) sext affects to 32 bits result (bus), but in divisions a is used as result
Expand Down Expand Up @@ -217,7 +261,7 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu
main_div * (a[2] + a[3] * CHUNK_SIZE));
col witness bus_res1;

bus_res1 === sext * 0xFFFFFFFF + (1 - m32) * bus_res1_64;
bus_res1 === sext * 0xFFFF_FFFF + (1 - m32) * bus_res1_64;

m32 * bus_a1 === 0;
m32 * bus_b1 === 0;
Expand All @@ -229,7 +273,7 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu
bus_a0, bus_a1,
bus_b0, bus_b1,
bus_res0, bus_res1,
0], mul: multiplicity);
div_by_zero /*+ div_overflow*/], mul: multiplicity);

// TODO: remainder check
// lookup_assumes(operation_bus_id, [debug_main_step, signed * (OP_LT_ABS - OP_LT) + OP_LT,
Expand All @@ -244,7 +288,8 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu
col witness range_ab;
col witness range_cd;

arith_table_assumes(op, m32, div, na, nb, np, nr, sext, main_mul, main_div, signed, range_ab, range_cd);
arith_table_assumes(op, m32, div, na, nb, np, nr, sext, div_by_zero, div_overflow, main_mul,
main_div, signed, range_ab, range_cd);

const expr range_a3 = range_ab;
const expr range_a1 = range_ab + 26;
Expand Down
109 changes: 72 additions & 37 deletions state-machines/arith/pil/arith_table.pil
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,19 @@ const int ARITH_TABLE_ID = 331;

airtemplate ArithTable(int N = 2**7, int generate_table = 1) {

// TABLE
// op,m32|div|na|nb|np|nr|sext,range_ab,range_cd

// div m32 sa sb primary secondary opcodes na nb np nr sext(c)
// -----------------------------------------------------------------------------
// 0 0 0 0 mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0
// 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 d3 =0 =0 =0 a3, d3
// 0 0 1 1 mul mulh (0xb4,0xb5) a3 b3 d3 =0 =0 =0 a3,b3, d3
// 0 1 0 0 mul_w *n/a* (0xb6,0xb7) =0 =0 =0 =0 c1 =0

// div m32 sa sb primary secondary opcodes na nb np nr sext(a,d)(*2)
// ------------------------------------------------------------------------------
// 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0
// 1 0 1 1 div rem (0xba,0xbb) a3 b3 c3 d3 =0 =0 a3,b3,c3,d3
// 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 a1 d1 a1 ,d1
// 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 c1 d1 a1 d1 a1,b1,c1,d1
// div m32 sa sb primary secondary opcodes na nb np nr sext(c)
// -----------------------------------------------------------------------------------
// 0 0 0 0 mulu muluh 0xb0 176 0xb1 177 =0 =0 =0 =0 =0 =0
// 0 0 1 0 *n/a* mulsuh 0xb2 - 0xb3 179 a3 =0 d3 =0 =0 =0 a3, d3
// 0 0 1 1 mul mulh 0xb4 180 0xb5 181 a3 b3 d3 =0 =0 =0 a3,b3, d3
// 0 1 0 0 mul_w *n/a* 0xb6 182 0xb7 - =0 =0 =0 =0 c1 =0

// div m32 sa sb primary secondary opcodes na nb np nr sext(a,d)(*2)
// ------------------------------------------------------------------------------------
// 1 0 0 0 divu remu 0xb8 184 0xb9 185 =0 =0 =0 =0 =0 =0
// 1 0 1 1 div rem 0xba 186 0xbb 187 a3 b3 c3 d3 =0 =0 a3,b3,c3,d3
// 1 1 0 0 divu_w remu_w 0xbc 188 0xbd 189 =0 =0 =0 =0 a1 d1 a1 ,d1
// 1 1 1 1 div_w rem_w 0xbe 190 0xbf 191 a1 b1 c1 d1 a1 d1 a1,b1,c1,d1

const int OPS[14] = [0xb0, 0xb1, 0xb3, 0xb4, 0xb5, 0xb6, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf];

Expand All @@ -33,9 +30,9 @@ airtemplate ArithTable(int N = 2**7, int generate_table = 1) {
int aborted = 0;

if (generate_table) {
int air.op2row[512];
for (int i = 0; i < 512; ++i) {
op2row[i] = -1;
int air.op2row[2048];
for (int i = 0; i < 2048; ++i) {
op2row[i] = 255;
}
}

Expand Down Expand Up @@ -114,21 +111,44 @@ airtemplate ArithTable(int N = 2**7, int generate_table = 1) {
m32 = 1;
}

for (int icase = 0; icase < 32; ++icase) {
int na = 0;
int nb = 0;
int nr = 0;
int np = 0;
int sext = 0;

if (0x01 & icase) na = 1;
if (0x02 & icase) nb = 1;
if (0x04 & icase) np = 1;
if (0x08 & icase) nr = 1;
if (0x10 & icase) sext = 1;
for (int icase = 0; icase < 128; ++icase) {
const int na = (0x01 & icase) ? 1 : 0;
const int nb = (0x02 & icase) ? 1 : 0;
const int np = (0x04 & icase) ? 1 : 0;
const int nr = (0x08 & icase) ? 1 : 0;
const int sext = (0x10 & icase) ? 1 : 0;
const int div_by_zero = (0x20 & icase) ? 1 : 0;
const int div_overflow = (0x40 & icase) ? 1 : 0;

const int signed = (sa || sb) ? 1 : 0;

// division by zero (dividend: x, divisor: 0)
//
// DIV,DIVU 0xFFFF_FFFF_FFFF_FFFF
// REM,REMU x
// DIV_W,DIVU_W 0xFFFF_FFFF_FFFF_FFFF
// REM_W,REMU_W x

// division overflow 64 (divend: 0x8000_0000_0000_0000, divisor: 0xFFFF_FFFF_FFFF_FFFF)
//
// DIV 0x8000_0000_0000_0000
// REM 0

// division overflow 32 (divend: 0x8000_0000, divisor: 0xFFFF_FFFF)
//
// DIV_W 0xFFFF_FFFF_8000_0000
// REM_W 0

// div_by_zero
// signed:1 => na:1 nb:0 np = nr (0,1)
// signed:0 => na:0 nb:0 np:0 nr:0

// div_overflow
// signed:1 => na:1 nb:1 np:1 nr:0 sext:0

if (div_by_zero && (!div || nb || np != nr || signed != na)) continue;
if (div_by_zero && main_div && m32 && !sext) continue;
if (div_overflow && (!div || !signed || !na || !nb || !np || nr)) continue;
if (sext && !m32) continue;
if (nr && !div) continue;
if (na && !sa) continue;
Expand All @@ -137,8 +157,8 @@ airtemplate ArithTable(int N = 2**7, int generate_table = 1) {
if (nr && !sa && !sb) continue;
if (np && na == nb && !div) continue;
if (np && !na && !nb && !nr && div) continue;
if (na && !nb && !nr && !np && div) continue;
if (np && na && nb) continue;
if (na && !nb && !nr && !np && div && !div_by_zero) continue;
if (np && na && nb && !div_overflow) continue;
if (!np & nr) continue;
if (m32 && signed && main_div && na != sext) continue;
if (m32 && signed && div && !main_div && nr != sext) continue;
Expand Down Expand Up @@ -186,7 +206,9 @@ airtemplate ArithTable(int N = 2**7, int generate_table = 1) {
}
}
const int flags = m32 + 2 * div + 4 * na + 8 * nb + 16 * np + 32 * nr + 64 * sext +
128 * main_mul + 256 * main_div + 512 * signed;
128 * div_by_zero + 256 * div_overflow + 512 * main_mul +
1024 * main_div + 2048 * signed;

int range_ab = (range_a3 + range_a1) * 3 + range_b3 + range_b1;
if ((range_a1 + range_b1) > 0) {
range_ab = range_ab + 8;
Expand All @@ -203,7 +225,11 @@ airtemplate ArithTable(int N = 2**7, int generate_table = 1) {
RANGE_CD[index] = range_cd;

if (generate_table) {
op2row[(opcode - 0xb0) * 32 + icase] = index;
println(`OP:${opcode} na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} m32:${m32} div:${div}`,
`div_by_zero:${div_by_zero} div_overflow:${div_overflow} sa:${sa} sb:${sb} main_mul:${main_mul}`,
`main_div:${main_div} signed:${signed} range_ab:${range_ab} range_cd:${range_cd} index:${(opcode - 0xb0) * 128 + icase} icase:${icase}`);

op2row[(opcode - 0xb0) * 128 + icase] = index;
code = code + `[${opcode}, ${flags}, ${range_ab}, ${range_cd}],`;
}
++index;
Expand All @@ -212,9 +238,16 @@ airtemplate ArithTable(int N = 2**7, int generate_table = 1) {
const int size = index;

println("ARITH_TABLE SIZE: ", size);
assert(size < 256);

if (generate_table) {
println(`pub const ROWS: usize = ${size};`);
println("pub static ARITH_TABLE_ROWS: [i16; 512] = [", op2row, "];");
println("const __: u8 = 255;");
string _op2row = "";
for (int i = 0; i < 2048; ++i) {
_op2row = _op2row + ((op2row[i] == 255) ? "__":string(op2row[i])) + ",";
}
println("pub static ARITH_TABLE_ROWS: [u8; 2048] = [", _op2row, "];");
println(`pub static ARITH_TABLE: [[u16; 4]; ROWS] = [${code}];`);
}

Expand All @@ -238,11 +271,13 @@ airtemplate ArithTable(int N = 2**7, int generate_table = 1) {

function arith_table_assumes( const expr op, const expr flag_m32, const expr flag_div, const expr flag_na,
const expr flag_nb, const expr flag_np, const expr flag_nr, const expr flag_sext,
const expr flag_div_by_zero, const expr flag_div_overflow,
const expr flag_main_mul, const expr flag_main_div, const expr flag_signed,
const expr range_ab, const expr range_cd) {

lookup_assumes(ARITH_TABLE_ID, cols: [ op, flag_m32 + 2 * flag_div + 4 * flag_na + 8 * flag_nb +
16 * flag_np + 32 * flag_nr + 64 * flag_sext +
128 * flag_main_mul + 256 * flag_main_div + 512 * flag_signed,
128 * flag_div_by_zero + 256 * flag_div_overflow +
512 * flag_main_mul + 1024 * flag_main_div + 2048 * flag_signed,
range_ab, range_cd]);
}
22 changes: 20 additions & 2 deletions state-machines/arith/src/arith_full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl<F: Field> ArithFullSM<F> {

let mut aop = ArithOperation::new();
for (irow, input) in input.iter().enumerate() {
// println!("#{} ARITH op:0x{:X} a:0x{:X} b:0x{:X}", irow, input.opcode, input.a, input.b);
println!("#{} ARITH op:0x{:X} a:0x{:X} b:0x{:X}", irow, input.opcode, input.a, input.b);
aop.calculate(input.opcode, input.a, input.b);
let mut t: ArithRow<F> = Default::default();
for i in [0, 2] {
Expand Down Expand Up @@ -136,8 +136,24 @@ impl<F: Field> ArithFullSM<F> {
t.debug_main_step = F::from_canonical_u64(input.step);
t.range_ab = F::from_canonical_u8(aop.range_ab);
t.range_cd = F::from_canonical_u8(aop.range_cd);
t.div_by_zero = F::from_bool(aop.div_by_zero);
t.div_overflow = F::from_bool(aop.div_overflow);
t.inv_sum_all_bs = if aop.div && !aop.div_by_zero {
F::from_canonical_u64(aop.b[0] + aop.b[1] + aop.b[2] + aop.b[3]).inverse()
} else {
F::zero()
};

table_inputs.add_use(aop.op, aop.na, aop.nb, aop.np, aop.nr, aop.sext);
table_inputs.add_use(
aop.op,
aop.na,
aop.nb,
aop.np,
aop.nr,
aop.sext,
aop.div_by_zero,
aop.div_overflow,
);

t.fab = if aop.na != aop.nb { F::neg_one() } else { F::one() };
// na * (1 - 2 * nb);
Expand Down Expand Up @@ -204,6 +220,8 @@ impl<F: Field> ArithFullSM<F> {
false,
false,
false,
false,
false,
);
}
timer_stop_and_log_trace!(ARITH_PADDING);
Expand Down
Loading

0 comments on commit 7a00cb7

Please sign in to comment.