From 529dcd537d92abdadfbc019af060cb8c19f0bd75 Mon Sep 17 00:00:00 2001 From: lehugueni Date: Tue, 19 Nov 2024 15:30:17 +0100 Subject: [PATCH 1/3] fix innersum bgv --- schemes/bgv/evaluator.go | 37 +++++++++++++++++++++++++++++++++++++ schemes/bgv/params.go | 2 +- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/schemes/bgv/evaluator.go b/schemes/bgv/evaluator.go index 554f7c5f3..fb94c00c0 100644 --- a/schemes/bgv/evaluator.go +++ b/schemes/bgv/evaluator.go @@ -1505,6 +1505,43 @@ func (eval Evaluator) RotateHoistedLazyNew(level int, rotations []int, op0 *rlwe return } +// InnerSum computes the inner sum of the underlying slots (see [rlwe.Evaluator.InnerSum]). +// NB: in the slot encoding of BGV/BFV, the underlying N slots are arranged as 2 rows of N/2 slots. +// If n*batchSize < N/2, InnerSum computes the [rlwe.Evaluator.InnerSum] of each row separately. +// If n*batchSize = N, InnerSum computes the [rlwe.Evaluator.InnerSum] on the concatenation of both rows. +// NOTE: In this case, InnerSum performs an addition and a [Evaluator.RotateRowsNew] on top +// Otherwise, InnerSum returns an error. +func (eval Evaluator) InnerSum(ctIn *rlwe.Ciphertext, batchSize, n int, opOut *rlwe.Ciphertext) (err error) { + N := eval.parameters.N() + halfN := N >> 1 + l := n * batchSize + + if l > halfN { + if l != N { + return fmt.Errorf("innersum: n*batchSize=%d > N/2=%d and n*batchSize != N=%d", l, halfN, N) + } + + if err = eval.Evaluator.InnerSum(ctIn, batchSize, n/2, opOut); err != nil { + return + } + + var ctRot *rlwe.Ciphertext + ctRot, err = eval.RotateRowsNew(opOut) + if err != nil { + return + } + + if err = eval.Add(opOut, ctRot, opOut); err != nil { + return + } + + return + } + + err = eval.Evaluator.InnerSum(ctIn, batchSize, n, opOut) + return +} + // MatchScalesAndLevel updates the both input ciphertexts to ensures that their scale matches. // To do so it computes t0 * a = opOut * b such that: // - ct0.Scale * a = opOut.Scale: make the scales match. diff --git a/schemes/bgv/params.go b/schemes/bgv/params.go index 066263837..8ae690de7 100644 --- a/schemes/bgv/params.go +++ b/schemes/bgv/params.go @@ -257,7 +257,7 @@ func (p Parameters) GaloisElementForRowRotation() uint64 { // InnerSum operation with parameters batch and n. func (p Parameters) GaloisElementsForInnerSum(batch, n int) (galEls []uint64) { galEls = rlwe.GaloisElementsForInnerSum(p, batch, n) - if n > p.N()>>1 { + if n*batch > p.N()>>1 { galEls = append(galEls, p.GaloisElementForRowRotation()) } return From 141c402c9134739b22f1d9bf17f56fe46cac2672 Mon Sep 17 00:00:00 2001 From: lehugueni Date: Thu, 21 Nov 2024 09:26:19 +0100 Subject: [PATCH 2/3] tests + use maxslots instead of N --- schemes/bgv/bgv_test.go | 75 ++++++++++++++++++++++++++++++++++++++++ schemes/bgv/evaluator.go | 16 +++------ schemes/bgv/params.go | 2 +- 3 files changed, 81 insertions(+), 12 deletions(-) diff --git a/schemes/bgv/bgv_test.go b/schemes/bgv/bgv_test.go index 0dca91dd0..e46ddbad3 100644 --- a/schemes/bgv/bgv_test.go +++ b/schemes/bgv/bgv_test.go @@ -14,6 +14,7 @@ import ( "github.com/tuneinsight/lattigo/v6/core/rlwe" "github.com/tuneinsight/lattigo/v6/ring" + "github.com/tuneinsight/lattigo/v6/utils" ) var flagPrintNoise = flag.Bool("print-noise", false, "print the residual noise") @@ -665,6 +666,80 @@ func testEvaluatorBvg(tc *TestContext, t *testing.T) { } }) } + + // Naive implementation of the inner sum for reference + innersum := func(values []uint64, n, batchSize int) { + tmp := make([]uint64, len(values)) + copy(tmp, values) + for i := 1; i < n; i++ { + rot := utils.RotateSlice(tmp, i*batchSize) + for j := range values { + values[j] = (values[j] + rot[j]) % tc.Params.PlaintextModulus() + } + } + } + + for _, N := range []int{tc.Params.N(), tc.Params.MaxSlots()} { + for _, lvl := range testLevel { + t.Run(name("Evaluator/InnerSum/N slots", tc, lvl), func(t *testing.T) { + if lvl == 0 { + t.Skip("Skipping: Level = 0") + } + n := N >> 2 + batchSize := 1 << 2 + + galEls := tc.Params.GaloisElementsForInnerSum(batchSize, n) + evl := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...)) + + want, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + + innersum(want, n, batchSize) + + receiver := NewCiphertext(tc.Params, 1, lvl) + + require.NoError(t, evl.InnerSum(ciphertext0, batchSize, n, receiver)) + + have := make([]uint64, len(want)) + require.NoError(t, tc.Ecd.Decode(tc.Dec.DecryptNew(receiver), have)) + + for i := 0; i < len(want); i += n * batchSize { + require.Equal(t, want[i:i+batchSize], have[i:i+batchSize]) + } + }) + } + } + + for _, lvl := range testLevel { + t.Run(name("Evaluator/InnerSum/N/2 slots", tc, lvl), func(t *testing.T) { + if lvl == 0 { + t.Skip("Skipping: Level = 0") + } + n := 7 + batchSize := 13 + l := n * batchSize + halfN := tc.Params.MaxSlots() >> 1 + + galEls := tc.Params.GaloisElementsForInnerSum(batchSize, n) + evl := tc.Evl.WithKey(rlwe.NewMemEvaluationKeySet(nil, tc.Kgen.GenGaloisKeysNew(galEls, tc.Sk)...)) + + want, _, ciphertext0 := NewTestVector(tc.Params, tc.Ecd, tc.Enc, lvl, tc.Params.NewScale(3)) + + innersum(want[:halfN], n, batchSize) + innersum(want[halfN:], n, batchSize) + + receiver := NewCiphertext(tc.Params, 1, lvl) + + require.NoError(t, evl.InnerSum(ciphertext0, batchSize, n, receiver)) + + have := make([]uint64, len(want)) + require.NoError(t, tc.Ecd.Decode(tc.Dec.DecryptNew(receiver), have)) + + for i, j := 0, halfN; i < halfN; i, j = i+l, j+l { + require.Equal(t, want[i:i+batchSize], have[i:i+batchSize]) + require.Equal(t, want[j:j+batchSize], have[j:j+batchSize]) + } + }) + } } func testEvaluatorBfv(tc *TestContext, t *testing.T) { diff --git a/schemes/bgv/evaluator.go b/schemes/bgv/evaluator.go index fb94c00c0..7ab6a4441 100644 --- a/schemes/bgv/evaluator.go +++ b/schemes/bgv/evaluator.go @@ -1507,20 +1507,14 @@ func (eval Evaluator) RotateHoistedLazyNew(level int, rotations []int, op0 *rlwe // InnerSum computes the inner sum of the underlying slots (see [rlwe.Evaluator.InnerSum]). // NB: in the slot encoding of BGV/BFV, the underlying N slots are arranged as 2 rows of N/2 slots. -// If n*batchSize < N/2, InnerSum computes the [rlwe.Evaluator.InnerSum] of each row separately. -// If n*batchSize = N, InnerSum computes the [rlwe.Evaluator.InnerSum] on the concatenation of both rows. -// NOTE: In this case, InnerSum performs an addition and a [Evaluator.RotateRowsNew] on top -// Otherwise, InnerSum returns an error. +// If n*batchSize is a multiple of N, InnerSum computes the [rlwe.Evaluator.InnerSum] on the N slots. +// NOTE: In this case, InnerSum performs an addition and a [Evaluator.RotateRowsNew] on top. +// Otherwise, InnerSum computes the [rlwe.Evaluator.InnerSum] of each row separately. func (eval Evaluator) InnerSum(ctIn *rlwe.Ciphertext, batchSize, n int, opOut *rlwe.Ciphertext) (err error) { - N := eval.parameters.N() - halfN := N >> 1 + N := eval.parameters.MaxSlots() l := n * batchSize - if l > halfN { - if l != N { - return fmt.Errorf("innersum: n*batchSize=%d > N/2=%d and n*batchSize != N=%d", l, halfN, N) - } - + if l%N == 0 { if err = eval.Evaluator.InnerSum(ctIn, batchSize, n/2, opOut); err != nil { return } diff --git a/schemes/bgv/params.go b/schemes/bgv/params.go index 8ae690de7..84d405096 100644 --- a/schemes/bgv/params.go +++ b/schemes/bgv/params.go @@ -257,7 +257,7 @@ func (p Parameters) GaloisElementForRowRotation() uint64 { // InnerSum operation with parameters batch and n. func (p Parameters) GaloisElementsForInnerSum(batch, n int) (galEls []uint64) { galEls = rlwe.GaloisElementsForInnerSum(p, batch, n) - if n*batch > p.N()>>1 { + if n*batch%p.MaxSlots() == 0 { galEls = append(galEls, p.GaloisElementForRowRotation()) } return From b67c72b7bada102ea69debf87399e38f9287e4a7 Mon Sep 17 00:00:00 2001 From: lehugueni Date: Thu, 21 Nov 2024 10:06:00 +0100 Subject: [PATCH 3/3] return ctIn when n=1, batchSize=k*N --- schemes/bgv/evaluator.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/schemes/bgv/evaluator.go b/schemes/bgv/evaluator.go index 7ab6a4441..7b69557a8 100644 --- a/schemes/bgv/evaluator.go +++ b/schemes/bgv/evaluator.go @@ -1515,6 +1515,13 @@ func (eval Evaluator) InnerSum(ctIn *rlwe.Ciphertext, batchSize, n int, opOut *r l := n * batchSize if l%N == 0 { + if n == 1 { + if ctIn != opOut { + opOut.Copy(ctIn) + } + return + } + if err = eval.Evaluator.InnerSum(ctIn, batchSize, n/2, opOut); err != nil { return }