From c7a5ba569dd9f594c8102a48ac00e7f34a022d00 Mon Sep 17 00:00:00 2001
From: lwj
Date: Wed, 18 Sep 2024 10:10:20 +0800
Subject: [PATCH] [CHEETAH] Optimize the MulAA communication. (#850)

The previous implementation realized MulAA via two OLEs, computing the two
cross terms x0*y1 and x1*y0 separately, which introduces a larger
communication overhead. We switch to another strategy that computes the sum
x0*y1 + x1*y0 homomorphically. To further utilize the CPU resources, we split
a long vector into two subtasks and let Rank0 and Rank1 each handle one half.

# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #

## Possible side effects?

- Performance:
- Backward compatibility:

---
 libspu/mpc/cheetah/arith/cheetah_mul.cc      | 167 ++++++++++++++++++
 libspu/mpc/cheetah/arith/cheetah_mul.h       |  13 ++
 libspu/mpc/cheetah/arith/cheetah_mul_test.cc |  33 ++++
 libspu/mpc/cheetah/arith/simd_mul_prot.cc    |  60 +++++++
 libspu/mpc/cheetah/arith/simd_mul_prot.h     |  18 +-
 .../mpc/cheetah/arith/simd_mul_prot_test.cc  |  20 +--
 libspu/mpc/cheetah/arithmetic.cc             |  44 ++++-
 7 files changed, 333 insertions(+), 22 deletions(-)
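For intuition, the share bookkeeping behind the new MulShare path can be checked with plain 2^64 ring arithmetic. This is only a toy sketch with illustrative names (x0/x1/y0/y1/r); in the patch the cross term is computed under HE by FMAThenResponse, and the mask r corresponds to PrepareRandomMask's output.

```cpp
// Toy check of (x0 + x1) * (y0 + y1) = x0*y0 + (x0*y1 + x1*y0) + x1*y1
// over Z_{2^64}; unsigned wrap-around gives the ring arithmetic for free.
#include <cassert>
#include <cstdint>
#include <random>

int main() {
  std::mt19937_64 prg(42);
  uint64_t x0 = prg(), x1 = prg(), y0 = prg(), y1 = prg();
  uint64_t r = prg();  // random additive mask

  // Evaluator's output share: its local product plus the mask r.
  uint64_t share0 = x0 * y0 + r;
  // Peer's output share: the masked cross term (computed homomorphically in
  // the real protocol) plus its own local product.
  uint64_t share1 = (x0 * y1 + x1 * y0 - r) + x1 * y1;

  // The two output shares reconstruct x*y in the ring.
  assert(share0 + share1 == (x0 + x1) * (y0 + y1));
  return 0;
}
```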
diff --git a/libspu/mpc/cheetah/arith/cheetah_mul.cc b/libspu/mpc/cheetah/arith/cheetah_mul.cc
index 072b2978..779041a9 100644
--- a/libspu/mpc/cheetah/arith/cheetah_mul.cc
+++ b/libspu/mpc/cheetah/arith/cheetah_mul.cc
@@ -114,6 +114,10 @@ struct CheetahMul::Impl : public EnableCPRNG {
   NdArrayRef MulOLE(const NdArrayRef &shr, yacl::link::Context *conn,
                     bool evaluator, uint32_t msg_width_hint);
 
+  NdArrayRef MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr,
+                      yacl::link::Context *conn, bool evaluator,
+                      uint32_t msg_width_hint);
+
  protected:
   void LocalExpandSEALContexts(size_t target);
 
@@ -167,6 +171,16 @@ struct CheetahMul::Impl : public EnableCPRNG {
                        absl::Span<const uint64_t> rnd_mask,
                        yacl::link::Context *conn = nullptr);
 
+  // Enc(x0) * y1 + Enc(y0) * x1 - rand_mask
+  void FMAThenResponse(FieldType field, int64_t num_elts,
+                       const Options &options,
+                       absl::Span<const yacl::Buffer> ciphers_x0,
+                       absl::Span<const yacl::Buffer> ciphers_y0,
+                       absl::Span<const RLWEPt> plains_x1,
+                       absl::Span<const RLWEPt> plains_y1,
+                       absl::Span<const uint64_t> rnd_mask,
+                       yacl::link::Context *conn = nullptr);
+
   void PrepareRandomMask(FieldType field, int64_t size, const Options &options,
                          std::vector<uint64_t> &mask);
 
@@ -386,6 +400,79 @@ NdArrayRef CheetahMul::Impl::MulOLE(const NdArrayRef &shr,
   return DecryptArray(field, numel, options, recv_ct).reshape(shr.shape());
 }
 
+NdArrayRef CheetahMul::Impl::MulShare(const NdArrayRef &xshr,
+                                      const NdArrayRef &yshr,
+                                      yacl::link::Context *conn, bool evaluator,
+                                      uint32_t msg_width_hint) {
+  if (conn == nullptr) {
+    conn = lctx_.get();
+  }
+
+  auto eltype = xshr.eltype();
+  SPU_ENFORCE(eltype.isa<RingTy>(), "must be ring_type, got={}", eltype);
+  SPU_ENFORCE(yshr.eltype().isa<RingTy>(), "must be ring_type, got={}",
+              yshr.eltype());
+  SPU_ENFORCE(xshr.numel() > 0);
+  SPU_ENFORCE_EQ(xshr.shape(), yshr.shape());
+
+  auto field = eltype.as<RingTy>()->field();
+  Options options;
+  options.ring_bitlen = SizeOf(field) * 8;
+  options.msg_bitlen =
+      msg_width_hint == 0 ? options.ring_bitlen : msg_width_hint;
+  SPU_ENFORCE(options.msg_bitlen > 0 &&
+              options.msg_bitlen <= options.ring_bitlen);
+  LazyExpandSEALContexts(options, conn);
+  LazyInitModSwitchHelper(options);
+
+  size_t numel = xshr.numel();
+  int nxt_rank = conn->NextRank();
+
+  // x0*y0 + (x0*y1 + x1*y0) + x1*y1
+  if (evaluator) {
+    std::vector<RLWEPt> encoded_x0;
+    std::vector<RLWEPt> encoded_y0;
+    EncodeArray(xshr, false, options, &encoded_x0);
+    EncodeArray(yshr, false, options, &encoded_y0);
+
+    size_t payload_sze = encoded_x0.size();
+    std::vector<yacl::Buffer> recv_ct_x1(payload_sze);
+    std::vector<yacl::Buffer> recv_ct_y1(payload_sze);
+    auto io_task = std::async(std::launch::async, [&]() {
+      for (size_t idx = 0; idx < payload_sze; ++idx) {
+        recv_ct_x1[idx] = conn->Recv(nxt_rank, "");
+      }
+      for (size_t idx = 0; idx < payload_sze; ++idx) {
+        recv_ct_y1[idx] = conn->Recv(nxt_rank, "");
+      }
+    });
+
+    std::vector<uint64_t> random_share_mask;
+    PrepareRandomMask(field, xshr.numel(), options, random_share_mask);
+
+    // wait for IO
+    io_task.get();
+    FMAThenResponse(field, numel, options, recv_ct_x1, recv_ct_y1, encoded_x0,
+                    encoded_y0, absl::MakeConstSpan(random_share_mask), conn);
+    // convert x \in [0, P) to [0, 2^k) by round(2^k*x/P)
+    auto &ms_helper = ms_helpers_.find(options)->second;
+    auto out = ms_helper.ModulusDownRNS(field, xshr.shape(), random_share_mask)
+                   .reshape(xshr.shape());
+    ring_add_(out, ring_mul(xshr, yshr));
+    return out;
+  }
+
+  size_t payload_sze = EncryptArrayThenSend(xshr, options, conn);
+  (void)EncryptArrayThenSend(yshr, options, conn);
+  std::vector<yacl::Buffer> recv_ct(payload_sze);
+  for (size_t idx = 0; idx < payload_sze; ++idx) {
+    recv_ct[idx] = conn->Recv(nxt_rank, "");
+  }
+  auto out = DecryptArray(field, numel, options, recv_ct).reshape(xshr.shape());
+  ring_add_(out, ring_mul(xshr, yshr));
+  return out;
+}
+
 size_t CheetahMul::Impl::EncryptArrayThenSend(const NdArrayRef &array,
                                               const Options &options,
                                               yacl::link::Context *conn) {
@@ -573,6 +660,72 @@ void CheetahMul::Impl::MulThenResponse(FieldType, int64_t num_elts,
   }
 }
 
+void CheetahMul::Impl::FMAThenResponse(
+    FieldType, int64_t num_elts, const Options &options,
+    absl::Span<const yacl::Buffer> ciphers_x0,
+    absl::Span<const yacl::Buffer> ciphers_y0,
+    absl::Span<const RLWEPt> plains_x1, absl::Span<const RLWEPt> plains_y1,
+    absl::Span<const uint64_t> rnd_mask, yacl::link::Context *conn) {
+  SPU_ENFORCE(!ciphers_x0.empty(), "CheetahMul: empty cipher");
+  SPU_ENFORCE(!ciphers_y0.empty(), "CheetahMul: empty cipher");
+  SPU_ENFORCE_EQ(ciphers_x0.size(), ciphers_y0.size());
+  SPU_ENFORCE_EQ(plains_x1.size(), ciphers_x0.size(),
+                 "CheetahMul: ct/pt size mismatch");
+  SPU_ENFORCE_EQ(plains_y1.size(), ciphers_y0.size(),
+                 "CheetahMul: ct/pt size mismatch");
+
+  const int64_t num_splits = CeilDiv(num_elts, num_slots());
+  const int64_t num_seal_ctx = WorkingContextSize(options);
+  const int64_t num_ciphers = num_seal_ctx * num_splits;
+  SPU_ENFORCE(ciphers_x0.size() == (size_t)num_ciphers,
+              "CheetahMul : expect {} != {}", num_ciphers, ciphers_x0.size());
+  SPU_ENFORCE(rnd_mask.size() == (size_t)num_elts * num_seal_ctx,
+              "CheetahMul: rnd_mask size mismatch");
+
+  std::vector<yacl::Buffer> response(num_ciphers);
+  yacl::parallel_for(0, num_ciphers, [&](int64_t job_bgn, int64_t job_end) {
+    RLWECt ct_x;
+    RLWECt ct_y;
+    std::vector<uint64_t> u64tmp(num_slots(), 0);
+    for (int64_t job_id = job_bgn; job_id < job_end; ++job_id) {
+      int64_t cntxt_id = job_id / num_splits;
+      int64_t split_id = job_id % num_splits;
+
+      int64_t slice_bgn = split_id * num_slots();
+      int64_t slice_n = std::min(num_slots(), num_elts - slice_bgn);
+      // offset by context id
+      slice_bgn += cntxt_id * num_elts;
+
+      DecodeSEALObject(ciphers_x0[job_id], seal_cntxts_[cntxt_id], &ct_x);
+      DecodeSEALObject(ciphers_y0[job_id], seal_cntxts_[cntxt_id], &ct_y);
+
+      // ct_x <- Re-randomize(ct_x * pt_y + ct_y * pt_x) - random_mask
+      simd_mul_instances_[cntxt_id]->FMAThenReshareInplace(
+          {&ct_x, 1}, {&ct_y, 1}, plains_y1.subspan(job_id, 1),
+          plains_x1.subspan(job_id, 1), rnd_mask.subspan(slice_bgn, slice_n),
+          *peer_pub_key_, seal_cntxts_[cntxt_id]);
+
+      response[job_id] = EncodeSEALObject(ct_x);
+    }
+  });
+
+  if (conn == nullptr) {
+    conn = lctx_.get();
+  }
+
+  int nxt_rank = conn->NextRank();
+  for (int64_t i = 0; i < num_ciphers; i += kCtAsyncParallel) {
+    int64_t this_batch = std::min(num_ciphers - i, kCtAsyncParallel);
+    conn->Send(nxt_rank, response[i],
+               fmt::format("FMAThenResponse ct[{}] to rank{}", i, nxt_rank));
+    for (int64_t j = 1; j < this_batch; ++j) {
+      conn->SendAsync(
+          nxt_rank, response[i + j],
+          fmt::format("FMAThenResponse ct[{}] to rank{}", i + j, nxt_rank));
+    }
+  }
+}
+
 NdArrayRef CheetahMul::Impl::DecryptArray(
     FieldType field, int64_t size, const Options &options,
     const std::vector<yacl::Buffer> &ct_array) {
@@ -625,6 +778,20 @@ size_t CheetahMul::OLEBatchSize() const {
   return impl_->OLEBatchSize();
 }
 
+NdArrayRef CheetahMul::MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr,
+                                yacl::link::Context *conn, bool is_evaluator,
+                                uint32_t msg_width_hint) {
+  SPU_ENFORCE(impl_ != nullptr);
+  SPU_ENFORCE(conn != nullptr);
+  return impl_->MulShare(xshr, yshr, conn, is_evaluator, msg_width_hint);
+}
+
+NdArrayRef CheetahMul::MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr,
+                                bool is_evaluator, uint32_t msg_width_hint) {
+  SPU_ENFORCE(impl_ != nullptr);
+  return impl_->MulShare(xshr, yshr, nullptr, is_evaluator, msg_width_hint);
+}
+
 NdArrayRef CheetahMul::MulOLE(const NdArrayRef &inp, yacl::link::Context *conn,
                               bool is_evaluator, uint32_t msg_width_hint) {
   SPU_ENFORCE(impl_ != nullptr);
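The evaluator's output share above is produced from the HE-domain mask by the step commented as "convert x \in [0, P) to [0, 2^k) by round(2^k*x/P)". Below is a minimal single-modulus sketch of that rounding; the real ModulusDownRNS works over an RNS basis of SEAL primes, and the modulus and bit width used here are made up for illustration.

```cpp
#include <cstdint>
#include <cstdio>

// round(2^k * x / P) for x in [0, P), k < 64.
uint64_t ModDownToy(uint64_t x, uint64_t P, int k) {
  unsigned __int128 num = ((unsigned __int128)x << k) + P / 2;  // +P/2 rounds
  return (uint64_t)(num / P) & ((1ULL << k) - 1);
}

int main() {
  const uint64_t P = (1ULL << 45) - 127;  // illustrative 45-bit modulus
  const int k = 32;                       // illustrative ring bit width
  printf("%llu\n", (unsigned long long)ModDownToy(123456789ULL, P, k));
  return 0;
}
```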
diff --git a/libspu/mpc/cheetah/arith/cheetah_mul.h b/libspu/mpc/cheetah/arith/cheetah_mul.h
index e687477a..304a7871 100644
--- a/libspu/mpc/cheetah/arith/cheetah_mul.h
+++ b/libspu/mpc/cheetah/arith/cheetah_mul.h
@@ -44,14 +44,27 @@ class CheetahMul {
 
   void LazyInitKeys(FieldType field, uint32_t msg_width_hint = 0);
 
+  // x, y => [x*y] for two private inputs
   // NOTE: make sure to call InitKeys first
   NdArrayRef MulOLE(const NdArrayRef& inp, yacl::link::Context* conn,
                     bool is_evaluator, uint32_t msg_width_hint = 0);
 
+  // x, y => [x*y] for two private inputs
   // NOTE: make sure to call InitKeys first
   NdArrayRef MulOLE(const NdArrayRef& inp, bool is_evaluator,
                     uint32_t msg_width_hint = 0);
 
+  // [x], [y] => [x*y] for two shares
+  // NOTE: make sure to call InitKeys first
+  NdArrayRef MulShare(const NdArrayRef& x, const NdArrayRef& y,
+                      yacl::link::Context* conn, bool is_evaluator,
+                      uint32_t msg_width_hint = 0);
+
+  // [x], [y] => [x*y] for two shares
+  // NOTE: make sure to call InitKeys first
+  NdArrayRef MulShare(const NdArrayRef& x, const NdArrayRef& y,
+                      bool is_evaluator, uint32_t msg_width_hint = 0);
+
   int Rank() const;
 
   size_t OLEBatchSize() const;
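A minimal two-party usage sketch of the new MulShare overloads, in the same style as the unit test added below. The include paths and the simulate harness follow the existing tests, the helper name MulShareDemo is illustrative, exactly one rank passes is_evaluator = true, and (as the header notes) keys must be initialized before MulShare.

```cpp
#include <memory>
#include <vector>

#include "libspu/mpc/cheetah/arith/cheetah_mul.h"
#include "libspu/mpc/utils/ring_ops.h"
#include "libspu/mpc/utils/simulate.h"

namespace spu::mpc::cheetah {

// Each rank feeds in its additive shares of x and y; the two returned arrays
// are additive shares of x * y.
void MulShareDemo(FieldType field, int64_t n) {
  auto x = ring_rand(field, {n});
  auto y = ring_rand(field, {n});
  std::vector<NdArrayRef> xs{ring_rand(field, {n})};
  xs.push_back(ring_sub(x, xs[0]));
  std::vector<NdArrayRef> ys{ring_rand(field, {n})};
  ys.push_back(ring_sub(y, ys[0]));

  std::vector<NdArrayRef> zs(2);
  utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> lctx) {
    size_t rank = lctx->Rank();
    CheetahMul mul(lctx, /*allow_approx=*/false);
    mul.LazyInitKeys(field);  // NOTE: keys must be ready before MulShare
    zs[rank] = mul.MulShare(xs[rank], ys[rank], /*is_evaluator=*/rank == 0);
  });
  // ring_add(zs[0], zs[1]) now equals ring_mul(x, y).
}

}  // namespace spu::mpc::cheetah
```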
diff --git a/libspu/mpc/cheetah/arith/cheetah_mul_test.cc b/libspu/mpc/cheetah/arith/cheetah_mul_test.cc
index d22040c1..4cd9635e 100644
--- a/libspu/mpc/cheetah/arith/cheetah_mul_test.cc
+++ b/libspu/mpc/cheetah/arith/cheetah_mul_test.cc
@@ -207,4 +207,37 @@ TEST_P(CheetahMulTest, MixedRingSizeMul) {
   EXPECT_TRUE(ring_all_equal(expected2, computed2, kMaxDiff));
 }
 
+TEST_P(CheetahMulTest, MulShare) {
+  size_t kWorldSize = 2;
+  auto field = std::get<0>(GetParam());
+  int64_t n = std::get<1>(GetParam());
+  bool allow_approx = std::get<2>(GetParam());
+
+  auto a_bits = ring_rand(field, {n});
+  auto b_bits = ring_rand(field, {n});
+
+  std::vector<NdArrayRef> a_shr(kWorldSize);
+  std::vector<NdArrayRef> b_shr(kWorldSize);
+  a_shr[0] = ring_rand(field, {n});
+  b_shr[0] = ring_rand(field, {n});
+  a_shr[1] = ring_sub(a_bits, a_shr[0]);
+  b_shr[1] = ring_sub(b_bits, b_shr[0]);
+
+  std::vector<NdArrayRef> result(kWorldSize);
+  utils::simulate(kWorldSize, [&](std::shared_ptr<yacl::link::Context> lctx) {
+    int rank = lctx->Rank();
+    // (a0 + a1) * (b0 + b1)
+    // a0*b0 + a0*b1 + a1*b0 + a1*b1
+    auto mul = std::make_shared<CheetahMul>(lctx, allow_approx);
+
+    result[rank] = mul->MulShare(a_shr[rank], b_shr[rank], rank == 0);
+  });
+
+  auto expected = ring_mul(a_bits, b_bits);
+  auto computed = ring_add(result[0], result[1]);
+
+  const int64_t kMaxDiff = allow_approx ? 1 : 0;
+  EXPECT_TRUE(ring_all_equal(expected, computed, kMaxDiff));
+}
+
 }  // namespace spu::mpc::cheetah::test
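The rnd_mask span consumed by FMAThenResponse (in cheetah_mul.cc above) holds num_elts values per working SEAL context, concatenated, and each ciphertext job takes one SIMD-sized slice of it before handing the slice to FMAThenReshareInplace in the next file. The standalone sketch below just replays that index arithmetic with made-up sizes.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_elts = 10000;  // illustrative vector length
  const int64_t num_slots = 4096;  // SIMD lanes per ciphertext
  const int64_t num_seal_ctx = 2;  // working SEAL contexts
  const int64_t num_splits = (num_elts + num_slots - 1) / num_slots;  // CeilDiv
  const int64_t num_ciphers = num_seal_ctx * num_splits;

  for (int64_t job_id = 0; job_id < num_ciphers; ++job_id) {
    int64_t cntxt_id = job_id / num_splits;
    int64_t split_id = job_id % num_splits;
    int64_t slice_bgn = split_id * num_slots;
    int64_t slice_n = std::min(num_slots, num_elts - slice_bgn);
    slice_bgn += cntxt_id * num_elts;  // offset by context id
    printf("job %lld -> ctx %lld, rnd_mask[%lld, %lld)\n", (long long)job_id,
           (long long)cntxt_id, (long long)slice_bgn,
           (long long)(slice_bgn + slice_n));
  }
  return 0;
}
```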
diff --git a/libspu/mpc/cheetah/arith/simd_mul_prot.cc b/libspu/mpc/cheetah/arith/simd_mul_prot.cc
index f347b87b..74d0b291 100644
--- a/libspu/mpc/cheetah/arith/simd_mul_prot.cc
+++ b/libspu/mpc/cheetah/arith/simd_mul_prot.cc
@@ -222,6 +222,66 @@ void SIMDMulProt::MulThenReshareInplace(absl::Span<RLWECt> ct,
   }
 }
 
+// Compute ct0 * pt0 + ct1 * pt1 - mask mod p
+void SIMDMulProt::FMAThenReshareInplace(absl::Span<RLWECt> ct0,
+                                        absl::Span<RLWECt> ct1,
+                                        absl::Span<const RLWEPt> pt0,
+                                        absl::Span<const RLWEPt> pt1,
+                                        absl::Span<const uint64_t> share_mask,
+                                        const RLWEPublicKey &public_key,
+                                        const seal::SEALContext &context) {
+  SPU_ENFORCE_EQ(ct0.size(), ct1.size());
+  SPU_ENFORCE_EQ(pt0.size(), pt1.size());
+  SPU_ENFORCE_EQ(ct0.size(), pt0.size());
+  SPU_ENFORCE_EQ(CeilDiv(share_mask.size(), (size_t)simd_lane_), ct0.size());
+
+  seal::Evaluator evaluator(context);
+  RLWECt zero_enc;
+  RLWEPt rnd;
+
+  constexpr int kMarginBitsForDec = 10;
+  seal::parms_id_type final_level_id = context.last_parms_id();
+  while (final_level_id != context.first_parms_id()) {
+    auto cntxt = context.get_context_data(final_level_id);
+    if (cntxt->total_coeff_modulus_bit_count() >=
+        kMarginBitsForDec + cntxt->parms().plain_modulus().bit_count()) {
+      break;
+    }
+    final_level_id = cntxt->prev_context_data()->parms_id();
+  }
+
+  RLWECt tmp_ct;
+  for (size_t i = 0; i < ct0.size(); ++i) {
+    // 1. Ct-Pt Mul
+    evaluator.multiply_plain_inplace(ct0[i], pt0[i]);
+    evaluator.multiply_plain(ct1[i], pt1[i], tmp_ct);
+    evaluator.add_inplace(ct0[i], tmp_ct);
+
+    // 2. Noise flooding
+    NoiseFloodInplace(ct0[i], context);
+
+    // 3. Drop some modulus for a smaller communication
+    evaluator.mod_switch_to_inplace(ct0[i], final_level_id);
+
+    // 4. Re-randomize via adding enc(0)
+    seal::util::encrypt_zero_asymmetric(public_key, context, ct0[i].parms_id(),
+                                        ct0[i].is_ntt_form(), zero_enc);
+    evaluator.add_inplace(ct0[i], zero_enc);
+
+    // 5. Additive share
+    size_t slice_bgn = i * simd_lane_;
+    size_t slice_n =
+        std::min((size_t)simd_lane_, share_mask.size() - slice_bgn);
+    EncodeSingle(share_mask.subspan(slice_bgn, slice_n), rnd);
+    evaluator.sub_plain_inplace(ct0[i], rnd);
+
+    // 6. Truncate for smaller communication
+    if (ct0[i].coeff_modulus_size() == 1) {
+      TruncateBFVForDecryption(ct0[i], context);
+    }
+  }
+}
+
 void SIMDMulProt::NoiseFloodInplace(RLWECt &ct,
                                     const seal::SEALContext &context) {
   SPU_ENFORCE(seal::is_metadata_valid_for(ct, context));
diff --git a/libspu/mpc/cheetah/arith/simd_mul_prot.h b/libspu/mpc/cheetah/arith/simd_mul_prot.h
index 22d79f02..c18d7da5 100644
--- a/libspu/mpc/cheetah/arith/simd_mul_prot.h
+++ b/libspu/mpc/cheetah/arith/simd_mul_prot.h
@@ -55,11 +55,19 @@ class SIMDMulProt : public EnableCPRNG {
                              const RLWEPublicKey& public_key,
                              const seal::SEALContext& context);
 
-  void MulThenReshareInplaceOneBit(absl::Span<RLWECt> ct,
-                                   absl::Span<const RLWEPt> pt,
-                                   absl::Span<const uint64_t> share_mask,
-                                   const RLWEPublicKey& public_key,
-                                   const seal::SEALContext& context);
+  // ct0 * pt0 + ct1 * pt1 - mask
+  void FMAThenReshareInplace(absl::Span<RLWECt> ct0,
+                             absl::Span<RLWECt> ct1,
+                             absl::Span<const RLWEPt> pt0,
+                             absl::Span<const RLWEPt> pt1,
+                             absl::Span<const uint64_t> share_mask,
+                             const RLWEPublicKey& public_key,
+                             const seal::SEALContext& context);
+
+  [[deprecated]] void MulThenReshareInplaceOneBit(
+      absl::Span<RLWECt> ct, absl::Span<const RLWEPt> pt,
+      absl::Span<const uint64_t> share_mask, const RLWEPublicKey& public_key,
+      const seal::SEALContext& context);
 
   inline int64_t SIMDLane() const { return simd_lane_; }
 
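The level-selection loop in FMAThenReshareInplace drops ciphertext primes only while the remaining coefficient modulus keeps roughly kMarginBitsForDec bits of headroom above the plaintext modulus, so the masked result still decrypts correctly after noise flooding. A plain-integer sketch of that rule follows; the bit counts are invented, not a real SEAL chain.

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int kMarginBitsForDec = 10;
  const int plain_bits = 45;  // illustrative plaintext modulus bit count
  // total_coeff_modulus_bit_count per chain level, first (largest) to last.
  std::vector<int> level_bits = {120, 80, 40};

  // Mirror the while-loop above: start at the last (smallest) level and move
  // toward larger levels until the headroom requirement is met.
  size_t level = level_bits.size() - 1;
  while (level > 0 && level_bits[level] < plain_bits + kMarginBitsForDec) {
    --level;
  }
  printf("mod-switch down to level %zu (%d bits)\n", level, level_bits[level]);
  return 0;
}
```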
"NoiseFlood" : "Approx"); }); @@ -116,20 +116,10 @@ TEST_P(SIMDMulTest, Basic) { simd_mul_prot_->SymEncrypt(encode_b, *rlwe_sk_, *context_, false, absl::MakeSpan(encrypt_b)); - if (GetParam()) { - RandomPlain(absl::MakeSpan(out_a)); - simd_mul_prot_->MulThenReshareInplace(absl::MakeSpan(encrypt_b), encode_a, - absl::MakeConstSpan(out_a), - *rlwe_pk_, *context_); - } else { - simd_mul_prot_->MulThenReshareInplaceOneBit( - absl::MakeSpan(encrypt_b), encode_a, absl::MakeSpan(out_a), *rlwe_pk_, - *context_); - } - if (rep == 0) { - printf("rep ct.L %zd\n", encrypt_b[0].coeff_modulus_size()); - } - + RandomPlain(absl::MakeSpan(out_a)); + simd_mul_prot_->MulThenReshareInplace(absl::MakeSpan(encrypt_b), encode_a, + absl::MakeConstSpan(out_a), *rlwe_pk_, + *context_); auto _out_b = absl::MakeSpan(out_b); for (size_t i = 0; i < num_pt; ++i) { seal::Plaintext pt; diff --git a/libspu/mpc/cheetah/arithmetic.cc b/libspu/mpc/cheetah/arithmetic.cc index c1765b46..1e2b32b9 100644 --- a/libspu/mpc/cheetah/arithmetic.cc +++ b/libspu/mpc/cheetah/arithmetic.cc @@ -245,7 +245,7 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, int64_t batch_sze = ctx->getState()->get()->OLEBatchSize(); int64_t numel = x.numel(); - if (numel >= batch_sze) { + if (numel >= 2 * batch_sze) { return mulDirectly(ctx, x, y); } return mulWithBeaver(ctx, x, y); @@ -326,6 +326,46 @@ NdArrayRef MulAA::mulWithBeaver(KernelEvalContext* ctx, const NdArrayRef& x, return z.as(x.eltype()); } +#if 1 +NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const { + // Compute (x0 + x1) * (y0+ y1) + auto* comm = ctx->getState(); + auto* mul_prot = ctx->getState()->get(); + mul_prot->LazyInitKeys(x.eltype().as()->field()); + + auto fx = x.reshape({x.numel()}); + auto fy = y.reshape({y.numel()}); + const int64_t n = fx.numel(); + const int64_t nhalf = n / 2; + const int rank = comm->getRank(); + + // For long vectors, split into two subtasks. + auto dupx = ctx->getState()->duplx(); + std::future task = std::async(std::launch::async, [&] { + return mul_prot->MulShare(fx.slice({nhalf}, {n}, {1}), + fy.slice({nhalf}, {n}, {1}), dupx.get(), + /*evaluator*/ rank == 0); + }); + + std::vector out_slices(2); + out_slices[0] = + mul_prot->MulShare(fx.slice({0}, {nhalf}, {1}), + fy.slice({0}, {nhalf}, {1}), /*evaluato*/ rank != 0); + out_slices[1] = task.get(); + + NdArrayRef out(out_slices[0].eltype(), x.shape()); + int64_t offset = 0; + for (auto& out_slice : out_slices) { + std::memcpy(out.data() + offset, out_slice.data(), + out_slice.numel() * out.elsize()); + offset += out_slice.numel() * out.elsize(); + } + return out; +} +#else +// Old code for MulAA using two OLEs which commnuicate about 30% more than the +// above version. NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x, const NdArrayRef& y) const { // (x0 + x1) * (y0+ y1) @@ -335,7 +375,6 @@ NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x, mul_prot->LazyInitKeys(x.eltype().as()->field()); const int rank = comm->getRank(); - // auto fy = y.reshape({y.numel()}); auto dupx = ctx->getState()->duplx(); std::future task = std::async(std::launch::async, [&] { @@ -355,6 +394,7 @@ NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x, NdArrayRef x0y1 = task.get(); return ring_add(x0y1, ring_add(x1y0, ring_mul(x, y))).as(x.eltype()); } +#endif NdArrayRef MatMulVVS::proc(KernelEvalContext* ctx, const NdArrayRef& x, const NdArrayRef& y) const {