Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHEETAH] Optimize the MulAA communication. #850

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions libspu/mpc/cheetah/arith/cheetah_mul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ struct CheetahMul::Impl : public EnableCPRNG {
NdArrayRef MulOLE(const NdArrayRef &shr, yacl::link::Context *conn,
bool evaluator, uint32_t msg_width_hint);

NdArrayRef MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr,
yacl::link::Context *conn, bool evaluator,
uint32_t msg_width_hint);

protected:
void LocalExpandSEALContexts(size_t target);

Expand Down Expand Up @@ -167,6 +171,16 @@ struct CheetahMul::Impl : public EnableCPRNG {
absl::Span<const uint64_t> rnd_mask,
yacl::link::Context *conn = nullptr);

// Enc(x0) * y1 + Enc(y0) * x1 + rand_mask
void FMAThenResponse(FieldType field, int64_t num_elts,
const Options &options,
absl::Span<const yacl::Buffer> ciphers_x0,
absl::Span<const yacl::Buffer> ciphers_y0,
absl::Span<const RLWEPt> plains_x1,
absl::Span<const RLWEPt> plains_y1,
absl::Span<const uint64_t> rnd_mask,
yacl::link::Context *conn = nullptr);

void PrepareRandomMask(FieldType field, int64_t size, const Options &options,
std::vector<uint64_t> &mask);

Expand Down Expand Up @@ -386,6 +400,79 @@ NdArrayRef CheetahMul::Impl::MulOLE(const NdArrayRef &shr,
return DecryptArray(field, numel, options, recv_ct).reshape(shr.shape());
}

NdArrayRef CheetahMul::Impl::MulShare(const NdArrayRef &xshr,
const NdArrayRef &yshr,
yacl::link::Context *conn, bool evaluator,
uint32_t msg_width_hint) {
if (conn == nullptr) {
conn = lctx_.get();
}

auto eltype = xshr.eltype();
SPU_ENFORCE(eltype.isa<Ring2k>(), "must be ring_type, got={}", eltype);
SPU_ENFORCE(yshr.eltype().isa<Ring2k>(), "must be ring_type, got={}",
yshr.eltype());
SPU_ENFORCE(xshr.numel() > 0);
SPU_ENFORCE_EQ(xshr.shape(), yshr.shape());

auto field = eltype.as<Ring2k>()->field();
Options options;
options.ring_bitlen = SizeOf(field) * 8;
options.msg_bitlen =
msg_width_hint == 0 ? options.ring_bitlen : msg_width_hint;
SPU_ENFORCE(options.msg_bitlen > 0 &&
options.msg_bitlen <= options.ring_bitlen);
LazyExpandSEALContexts(options, conn);
LazyInitModSwitchHelper(options);

size_t numel = xshr.numel();
int nxt_rank = conn->NextRank();

// x0*y0 + <x0 + y1 + x1 * y0> + x1 * y1
if (evaluator) {
std::vector<RLWEPt> encoded_x0;
std::vector<RLWEPt> encoded_y0;
EncodeArray(xshr, false, options, &encoded_x0);
EncodeArray(yshr, false, options, &encoded_y0);

size_t payload_sze = encoded_x0.size();
std::vector<yacl::Buffer> recv_ct_x1(payload_sze);
std::vector<yacl::Buffer> recv_ct_y1(payload_sze);
auto io_task = std::async(std::launch::async, [&]() {
for (size_t idx = 0; idx < payload_sze; ++idx) {
recv_ct_x1[idx] = conn->Recv(nxt_rank, "");
}
for (size_t idx = 0; idx < payload_sze; ++idx) {
recv_ct_y1[idx] = conn->Recv(nxt_rank, "");
}
});

std::vector<uint64_t> random_share_mask;
PrepareRandomMask(field, xshr.numel(), options, random_share_mask);

// wait for IO
io_task.get();
FMAThenResponse(field, numel, options, recv_ct_x1, recv_ct_y1, encoded_x0,
encoded_y0, absl::MakeConstSpan(random_share_mask), conn);
// convert x \in [0, P) to [0, 2^k) by round(2^k*x/P)
auto &ms_helper = ms_helpers_.find(options)->second;
auto out = ms_helper.ModulusDownRNS(field, xshr.shape(), random_share_mask)
.reshape(xshr.shape());
ring_add_(out, ring_mul(xshr, yshr));
return out;
}

size_t payload_sze = EncryptArrayThenSend(xshr, options, conn);
(void)EncryptArrayThenSend(yshr, options, conn);
std::vector<yacl::Buffer> recv_ct(payload_sze);
for (size_t idx = 0; idx < payload_sze; ++idx) {
recv_ct[idx] = conn->Recv(nxt_rank, "");
}
auto out = DecryptArray(field, numel, options, recv_ct).reshape(xshr.shape());
ring_add_(out, ring_mul(xshr, yshr));
return out;
}

size_t CheetahMul::Impl::EncryptArrayThenSend(const NdArrayRef &array,
const Options &options,
yacl::link::Context *conn) {
Expand Down Expand Up @@ -573,6 +660,72 @@ void CheetahMul::Impl::MulThenResponse(FieldType, int64_t num_elts,
}
}

void CheetahMul::Impl::FMAThenResponse(
FieldType, int64_t num_elts, const Options &options,
absl::Span<const yacl::Buffer> ciphers_x0,
absl::Span<const yacl::Buffer> ciphers_y0,
absl::Span<const RLWEPt> plains_x1, absl::Span<const RLWEPt> plains_y1,
absl::Span<const uint64_t> rnd_mask, yacl::link::Context *conn) {
SPU_ENFORCE(!ciphers_x0.empty(), "CheetahMul: empty cipher");
SPU_ENFORCE(!ciphers_y0.empty(), "CheetahMul: empty cipher");
SPU_ENFORCE_EQ(ciphers_x0.size(), ciphers_y0.size());
SPU_ENFORCE_EQ(plains_x1.size(), ciphers_x0.size(),
"CheetahMul: ct/pt size mismatch");
SPU_ENFORCE_EQ(plains_y1.size(), ciphers_y0.size(),
"CheetahMul: ct/pt size mismatch");

const int64_t num_splits = CeilDiv(num_elts, num_slots());
const int64_t num_seal_ctx = WorkingContextSize(options);
const int64_t num_ciphers = num_seal_ctx * num_splits;
SPU_ENFORCE(ciphers_x0.size() == (size_t)num_ciphers,
"CheetahMul : expect {} != {}", num_ciphers, ciphers_x0.size());
SPU_ENFORCE(rnd_mask.size() == (size_t)num_elts * num_seal_ctx,
"CheetahMul: rnd_mask size mismatch");

std::vector<yacl::Buffer> response(num_ciphers);
yacl::parallel_for(0, num_ciphers, [&](int64_t job_bgn, int64_t job_end) {
RLWECt ct_x;
RLWECt ct_y;
std::vector<uint64_t> u64tmp(num_slots(), 0);
for (int64_t job_id = job_bgn; job_id < job_end; ++job_id) {
int64_t cntxt_id = job_id / num_splits;
int64_t split_id = job_id % num_splits;

int64_t slice_bgn = split_id * num_slots();
int64_t slice_n = std::min(num_slots(), num_elts - slice_bgn);
// offset by context id
slice_bgn += cntxt_id * num_elts;

DecodeSEALObject(ciphers_x0[job_id], seal_cntxts_[cntxt_id], &ct_x);
DecodeSEALObject(ciphers_y0[job_id], seal_cntxts_[cntxt_id], &ct_y);

// ct_x <- Re-randomize(ct_x * pt_y + ct_y * pt_x) - random_mask
simd_mul_instances_[cntxt_id]->FMAThenReshareInplace(
{&ct_x, 1}, {&ct_y, 1}, plains_y1.subspan(job_id, 1),
plains_x1.subspan(job_id, 1), rnd_mask.subspan(slice_bgn, slice_n),
*peer_pub_key_, seal_cntxts_[cntxt_id]);

response[job_id] = EncodeSEALObject(ct_x);
}
});

if (conn == nullptr) {
conn = lctx_.get();
}

int nxt_rank = conn->NextRank();
for (int64_t i = 0; i < num_ciphers; i += kCtAsyncParallel) {
int64_t this_batch = std::min(num_ciphers - i, kCtAsyncParallel);
conn->Send(nxt_rank, response[i],
fmt::format("FMAThenResponse ct[{}] to rank{}", i, nxt_rank));
for (int64_t j = 1; j < this_batch; ++j) {
conn->SendAsync(
nxt_rank, response[i + j],
fmt::format("FMAThenResponse ct[{}] to rank{}", i + j, nxt_rank));
}
}
}

NdArrayRef CheetahMul::Impl::DecryptArray(
FieldType field, int64_t size, const Options &options,
const std::vector<yacl::Buffer> &ct_array) {
Expand Down Expand Up @@ -625,6 +778,20 @@ size_t CheetahMul::OLEBatchSize() const {
return impl_->OLEBatchSize();
}

NdArrayRef CheetahMul::MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr,
yacl::link::Context *conn, bool is_evaluator,
uint32_t msg_width_hint) {
SPU_ENFORCE(impl_ != nullptr);
SPU_ENFORCE(conn != nullptr);
return impl_->MulShare(xshr, yshr, conn, is_evaluator, msg_width_hint);
}

NdArrayRef CheetahMul::MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr,
bool is_evaluator, uint32_t msg_width_hint) {
SPU_ENFORCE(impl_ != nullptr);
return impl_->MulShare(xshr, yshr, nullptr, is_evaluator, msg_width_hint);
}

NdArrayRef CheetahMul::MulOLE(const NdArrayRef &inp, yacl::link::Context *conn,
bool is_evaluator, uint32_t msg_width_hint) {
SPU_ENFORCE(impl_ != nullptr);
Expand Down
13 changes: 13 additions & 0 deletions libspu/mpc/cheetah/arith/cheetah_mul.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,27 @@ class CheetahMul {

void LazyInitKeys(FieldType field, uint32_t msg_width_hint = 0);

// x, y => [x*y] for two private inputs
// NOTE: make sure to call InitKeys first
NdArrayRef MulOLE(const NdArrayRef& inp, yacl::link::Context* conn,
bool is_evaluator, uint32_t msg_width_hint = 0);

// x, y => [x*y] for two private inputs
// NOTE: make sure to call InitKeys first
NdArrayRef MulOLE(const NdArrayRef& inp, bool is_evaluator,
uint32_t msg_width_hint = 0);

// [x], [y] => [x*y] for two shares
// NOTE: make sure to call InitKeys first
NdArrayRef MulShare(const NdArrayRef& x, const NdArrayRef& y,
yacl::link::Context* conn, bool is_evaluator,
uint32_t msg_width_hint = 0);

// [x], [y] => [x*y] for two shares
// NOTE: make sure to call InitKeys first
NdArrayRef MulShare(const NdArrayRef& x, const NdArrayRef& y,
bool is_evaluator, uint32_t msg_width_hint = 0);

int Rank() const;

size_t OLEBatchSize() const;
Expand Down
33 changes: 33 additions & 0 deletions libspu/mpc/cheetah/arith/cheetah_mul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,37 @@ TEST_P(CheetahMulTest, MixedRingSizeMul) {
EXPECT_TRUE(ring_all_equal(expected2, computed2, kMaxDiff));
}

TEST_P(CheetahMulTest, MulShare) {
size_t kWorldSize = 2;
auto field = std::get<0>(GetParam());
int64_t n = std::get<1>(GetParam());
bool allow_approx = std::get<2>(GetParam());

auto a_bits = ring_rand(field, {n});
auto b_bits = ring_rand(field, {n});

std::vector<NdArrayRef> a_shr(kWorldSize);
std::vector<NdArrayRef> b_shr(kWorldSize);
a_shr[0] = ring_rand(field, {n});
b_shr[0] = ring_rand(field, {n});
a_shr[1] = ring_sub(a_bits, a_shr[0]);
b_shr[1] = ring_sub(b_bits, b_shr[0]);

std::vector<NdArrayRef> result(kWorldSize);
utils::simulate(kWorldSize, [&](std::shared_ptr<yacl::link::Context> lctx) {
int rank = lctx->Rank();
// (a0 + a1) * (b0 + b1)
// a0*b0 + a0*b1 + a1*b0 + a1*b1
auto mul = std::make_shared<CheetahMul>(lctx, allow_approx);

result[rank] = mul->MulShare(a_shr[rank], b_shr[rank], rank == 0);
});

auto expected = ring_mul(a_bits, b_bits);
auto computed = ring_add(result[0], result[1]);

const int64_t kMaxDiff = allow_approx ? 1 : 0;
EXPECT_TRUE(ring_all_equal(expected, computed, kMaxDiff));
}

} // namespace spu::mpc::cheetah::test
60 changes: 60 additions & 0 deletions libspu/mpc/cheetah/arith/simd_mul_prot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,66 @@ void SIMDMulProt::MulThenReshareInplace(absl::Span<RLWECt> ct,
}
}

// Compute ct0 * pt1 + ct1 * pt1 - mask mod p
void SIMDMulProt::FMAThenReshareInplace(absl::Span<RLWECt> ct0,
absl::Span<const RLWECt> ct1,
absl::Span<const RLWEPt> pt0,
absl::Span<const RLWEPt> pt1,
absl::Span<const uint64_t> share_mask,
const RLWEPublicKey &public_key,
const seal::SEALContext &context) {
SPU_ENFORCE_EQ(ct0.size(), ct1.size());
SPU_ENFORCE_EQ(pt0.size(), pt1.size());
SPU_ENFORCE_EQ(ct0.size(), pt0.size());
SPU_ENFORCE_EQ(CeilDiv(share_mask.size(), (size_t)simd_lane_), ct0.size());

seal::Evaluator evaluator(context);
RLWECt zero_enc;
RLWEPt rnd;

constexpr int kMarginBitsForDec = 10;
seal::parms_id_type final_level_id = context.last_parms_id();
while (final_level_id != context.first_parms_id()) {
auto cntxt = context.get_context_data(final_level_id);
if (cntxt->total_coeff_modulus_bit_count() >=
kMarginBitsForDec + cntxt->parms().plain_modulus().bit_count()) {
break;
}
final_level_id = cntxt->prev_context_data()->parms_id();
}

RLWECt tmp_ct;
for (size_t i = 0; i < ct0.size(); ++i) {
// 1. Ct-Pt Mul
evaluator.multiply_plain_inplace(ct0[i], pt0[i]);
evaluator.multiply_plain(ct1[i], pt1[i], tmp_ct);
evaluator.add_inplace(ct0[i], tmp_ct);

// 2. Noise flooding
NoiseFloodInplace(ct0[i], context);

// 3. Drop some modulus for a smaller communication
evaluator.mod_switch_to_inplace(ct0[i], final_level_id);

// 4. Re-randomize via adding enc(0)
seal::util::encrypt_zero_asymmetric(public_key, context, ct0[i].parms_id(),
ct0[i].is_ntt_form(), zero_enc);
evaluator.add_inplace(ct0[i], zero_enc);

// 5. Additive share
size_t slice_bgn = i * simd_lane_;
size_t slice_n =
std::min((size_t)simd_lane_, share_mask.size() - slice_bgn);
EncodeSingle(share_mask.subspan(slice_bgn, slice_n), rnd);
evaluator.sub_plain_inplace(ct0[i], rnd);

// 6. Truncate for smaller communication
if (ct0[i].coeff_modulus_size() == 1) {
TruncateBFVForDecryption(ct0[i], context);
}
}
}

void SIMDMulProt::NoiseFloodInplace(RLWECt &ct,
const seal::SEALContext &context) {
SPU_ENFORCE(seal::is_metadata_valid_for(ct, context));
Expand Down
18 changes: 13 additions & 5 deletions libspu/mpc/cheetah/arith/simd_mul_prot.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,19 @@ class SIMDMulProt : public EnableCPRNG {
const RLWEPublicKey& public_key,
const seal::SEALContext& context);

void MulThenReshareInplaceOneBit(absl::Span<RLWECt> ct,
absl::Span<const RLWEPt> pt,
absl::Span<uint64_t> share_mask,
const RLWEPublicKey& public_key,
const seal::SEALContext& context);
// ct0 * pt0 + ct1 * pt1 + mask
void FMAThenReshareInplace(absl::Span<RLWECt> ct0,
absl::Span<const RLWECt> ct1,
absl::Span<const RLWEPt> pt0,
absl::Span<const RLWEPt> pt1,
absl::Span<const uint64_t> share_mask,
const RLWEPublicKey& public_key,
const seal::SEALContext& context);

[[deprecated]] void MulThenReshareInplaceOneBit(
absl::Span<RLWECt> ct, absl::Span<const RLWEPt> pt,
absl::Span<uint64_t> share_mask, const RLWEPublicKey& public_key,
const seal::SEALContext& context);

inline int64_t SIMDLane() const { return simd_lane_; }

Expand Down
20 changes: 5 additions & 15 deletions libspu/mpc/cheetah/arith/simd_mul_prot_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class SIMDMulTest : public ::testing::TestWithParam<bool>, public EnableCPRNG {
};

INSTANTIATE_TEST_SUITE_P(
Cheetah, SIMDMulTest, testing::Values(true, false),
Cheetah, SIMDMulTest, testing::Values(true),
[](const testing::TestParamInfo<SIMDMulTest::ParamType> &p) {
return fmt::format("{}", p.param ? "NoiseFlood" : "Approx");
});
Expand Down Expand Up @@ -116,20 +116,10 @@ TEST_P(SIMDMulTest, Basic) {
simd_mul_prot_->SymEncrypt(encode_b, *rlwe_sk_, *context_, false,
absl::MakeSpan(encrypt_b));

if (GetParam()) {
RandomPlain(absl::MakeSpan(out_a));
simd_mul_prot_->MulThenReshareInplace(absl::MakeSpan(encrypt_b), encode_a,
absl::MakeConstSpan(out_a),
*rlwe_pk_, *context_);
} else {
simd_mul_prot_->MulThenReshareInplaceOneBit(
absl::MakeSpan(encrypt_b), encode_a, absl::MakeSpan(out_a), *rlwe_pk_,
*context_);
}
if (rep == 0) {
printf("rep ct.L %zd\n", encrypt_b[0].coeff_modulus_size());
}

RandomPlain(absl::MakeSpan(out_a));
simd_mul_prot_->MulThenReshareInplace(absl::MakeSpan(encrypt_b), encode_a,
absl::MakeConstSpan(out_a), *rlwe_pk_,
*context_);
auto _out_b = absl::MakeSpan(out_b);
for (size_t i = 0; i < num_pt; ++i) {
seal::Plaintext pt;
Expand Down
Loading
Loading