From 1c5d8e5bb34b104103e979eae377c89c9cfd7d69 Mon Sep 17 00:00:00 2001 From: Jimmy MA Date: Wed, 16 Oct 2024 12:50:38 +0800 Subject: [PATCH] repo-sync-2024-10-15T16:59:25+0800 (#888) # Pull Request ## What problem does this PR solve? Issue Number: Fixed # ## Possible side effects? - Performance: - Backward compatibility: --- CHANGELOG.md | 1 + bazel/repositories.bzl | 6 +- libspu/compiler/front_end/hlo_importer.h | 2 + libspu/core/config.cc | 11 + libspu/core/type.h | 61 +++- libspu/core/type_test.cc | 20 ++ libspu/core/type_util.cc | 16 + libspu/core/type_util.h | 27 +- libspu/kernel/hal/BUILD.bazel | 1 + libspu/kernel/hal/fxp_approx.cc | 36 ++- libspu/kernel/hal/fxp_approx.h | 2 + libspu/kernel/hal/fxp_approx_test.cc | 28 +- libspu/kernel/hal/ring.cc | 15 +- libspu/kernel/hal/ring.h | 5 + libspu/mpc/aby3/oram.cc | 4 +- .../cheetah/nonlinear/compare_prot_test.cc | 24 +- libspu/mpc/common/BUILD.bazel | 2 + libspu/mpc/common/communicator.cc | 52 ++- libspu/mpc/common/communicator.h | 5 + libspu/mpc/common/prg_tensor.h | 29 +- libspu/mpc/semi2k/BUILD.bazel | 32 ++ libspu/mpc/semi2k/beaver/beaver_cache.h | 2 + .../mpc/semi2k/beaver/beaver_impl/BUILD.bazel | 3 + .../semi2k/beaver/beaver_impl/beaver_test.cc | 301 +++++++++++++++--- .../semi2k/beaver/beaver_impl/beaver_tfp.cc | 62 +++- .../semi2k/beaver/beaver_impl/beaver_tfp.h | 6 +- .../semi2k/beaver/beaver_impl/beaver_ttp.cc | 74 ++++- .../semi2k/beaver/beaver_impl/beaver_ttp.h | 6 +- .../beaver_impl/trusted_party/BUILD.bazel | 2 + .../trusted_party/trusted_party.cc | 72 ++++- .../beaver_impl/trusted_party/trusted_party.h | 2 + .../beaver_impl/ttp_server/beaver_server.cc | 26 +- .../beaver_impl/ttp_server/service.proto | 30 ++ libspu/mpc/semi2k/beaver/beaver_interface.h | 8 +- libspu/mpc/semi2k/exp.cc | 97 ++++++ libspu/mpc/semi2k/exp.h | 37 +++ libspu/mpc/semi2k/prime_utils.cc | 201 ++++++++++++ libspu/mpc/semi2k/prime_utils.h | 46 +++ libspu/mpc/semi2k/protocol.cc | 7 + libspu/mpc/semi2k/protocol_test.cc | 179 +++++++++++ libspu/mpc/utils/BUILD.bazel | 31 ++ libspu/mpc/utils/gfmp.h | 168 ++++++++++ libspu/mpc/utils/gfmp_ops.cc | 251 +++++++++++++++ libspu/mpc/utils/gfmp_ops.h | 45 +++ libspu/spu.proto | 20 ++ 45 files changed, 1962 insertions(+), 93 deletions(-) create mode 100644 libspu/mpc/semi2k/exp.cc create mode 100644 libspu/mpc/semi2k/exp.h create mode 100644 libspu/mpc/semi2k/prime_utils.cc create mode 100644 libspu/mpc/semi2k/prime_utils.h create mode 100644 libspu/mpc/utils/gfmp.h create mode 100644 libspu/mpc/utils/gfmp_ops.cc create mode 100644 libspu/mpc/utils/gfmp_ops.h diff --git a/CHANGELOG.md b/CHANGELOG.md index 27502fd6..0d09208e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ > > please add your unreleased change here. 
+- [Improvement] Optimize exponential computation for semi2k (**experimental**) - [Feature] Add more send/recv actions profiling ## 20240716 diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 6ba0fd47..72fd415e 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -50,10 +50,10 @@ def _libpsi(): http_archive, name = "psi", urls = [ - "https://github.com/secretflow/psi/archive/refs/tags/v0.4.3.dev240919.tar.gz", + "https://github.com/secretflow/psi/archive/refs/tags/v0.5.0.dev241016.tar.gz", ], - strip_prefix = "psi-0.4.3.dev240919", - sha256 = "1ee34fbbd9a8f36dea8f7c45588a858e8c31f3a38e60e1fc67cb428ea79334e3", + strip_prefix = "psi-0.5.0.dev241016", + sha256 = "1672e4284f819c40e34c65b0d5b1dfe4cc959b81d6f63daef7b39f7eb8d742e2", ) def _rules_proto_grpc(): diff --git a/libspu/compiler/front_end/hlo_importer.h b/libspu/compiler/front_end/hlo_importer.h index 8e57b7f0..551f2755 100644 --- a/libspu/compiler/front_end/hlo_importer.h +++ b/libspu/compiler/front_end/hlo_importer.h @@ -28,7 +28,9 @@ class CompilationContext; class HloImporter final { public: + // clang-format off explicit HloImporter(CompilationContext *context) : context_(context) {}; + // clang-format on /// Load a xla module and returns a mlir-hlo module mlir::OwningOpRef diff --git a/libspu/core/config.cc b/libspu/core/config.cc index 7f63fc67..81ca2f9d 100644 --- a/libspu/core/config.cc +++ b/libspu/core/config.cc @@ -62,6 +62,17 @@ void populateRuntimeConfig(RuntimeConfig& cfg) { if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_DEFAULT) { cfg.set_fxp_exp_mode(RuntimeConfig::EXP_TAYLOR); } + if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_PRIME) { + // 0 offset is not supported + if (cfg.experimental_exp_prime_offset() == 0) { + // For FM128 default offset is 13 + if (cfg.field() == FieldType::FM128) { + cfg.set_experimental_exp_prime_offset(13); + } + // TODO: set defaults for other fields, currently only FM128 is + // supported + } + } if (cfg.fxp_exp_iters() == 0) { cfg.set_fxp_exp_iters(8); diff --git a/libspu/core/type.h b/libspu/core/type.h index 62ebca90..2bcf6819 100644 --- a/libspu/core/type.h +++ b/libspu/core/type.h @@ -43,6 +43,16 @@ class Ring2k { FieldType field() const { return field_; } }; +// This trait means the data is maintained in Galois prime field. +class Gfp { + protected: + uint128_t prime_{0}; + + public: + virtual ~Gfp() = default; + uint128_t p() const { return prime_; } +}; + // The public interface. // // The value of this type is public visible for parties. 
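
The `Gfp` trait introduced above (and the `GfmpTy` type defined in the next hunk) keys everything off a Mersenne prime p = 2^k − 1, which is what makes the field arithmetic cheap: a double-width product can be reduced with a shift and an add instead of a division. The concrete helpers live in the new `libspu/mpc/utils/gfmp.h` (listed in the diffstat but not shown in this excerpt), so the snippet below is only an independent sketch of the reduction idea for the FM64 case (p = 2^61 − 1); the function name is made up for illustration and is not the SPU API.

```cpp
#include <cstdint>

// Sketch only: multiplication modulo the Mersenne prime p = 2^61 - 1.
// Writing the product as x = hi * 2^61 + lo and using 2^61 ≡ 1 (mod p)
// gives x ≡ hi + lo (mod p). Inputs are assumed canonical, i.e. in [0, p).
constexpr uint64_t kMpExp = 61;
constexpr uint64_t kPrime = (uint64_t{1} << kMpExp) - 1;

uint64_t mul_mod_mersenne61(uint64_t a, uint64_t b) {
  unsigned __int128 t = static_cast<unsigned __int128>(a) * b;
  uint64_t lo = static_cast<uint64_t>(t) & kPrime;   // low 61 bits
  uint64_t hi = static_cast<uint64_t>(t >> kMpExp);  // high part, worth hi * 2^61
  uint64_t r = lo + hi;                              // r < 2p, one correction suffices
  return r >= kPrime ? r - kPrime : r;
}
```

The same pattern generalizes to 2^31 − 1 and 2^127 − 1, matching the `FIELD_TO_MERSENNE_PRIME_EXP_MAP` entries (FM32 → 31, FM64 → 61, FM128 → 127) added to `type_util.h` later in this patch.
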
@@ -384,6 +394,54 @@ class RingTy : public TypeImpl { } }; +// Galois field type of Mersenne primes, e.g., 2^127-1 +class GfmpTy : public TypeImpl { + using Base = TypeImpl; + + protected: + size_t mersenne_prime_exp_; + + public: + using Base::Base; + explicit GfmpTy(FieldType field) { + field_ = field; + mersenne_prime_exp_ = GetMersennePrimeExp(field); + prime_ = (static_cast(1) << mersenne_prime_exp_) - 1; + } + + static std::string_view getStaticId() { return "Gfmp"; } + + size_t size() const override { + if (field_ == FT_INVALID) { + return 0; + } + return SizeOf(GetStorageType(field_)); + } + + size_t mp_exp() const { return mersenne_prime_exp_; } + + void fromString(std::string_view detail) override { + auto comma = detail.find_first_of(','); + auto field_str = detail.substr(0, comma); + auto mp_exp_str = detail.substr(comma + 1); + SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_), + "parse failed from={}", detail); + mersenne_prime_exp_ = std::stoul(std::string(mp_exp_str)); + prime_ = (static_cast(1) << mersenne_prime_exp_) - 1; + } + + std::string toString() const override { + return fmt::format("{},{}", FieldType_Name(field()), mersenne_prime_exp_); + } + + bool equals(TypeObject const* other) const override { + auto const* derived_other = dynamic_cast(other); + SPU_ENFORCE(derived_other); + return field() == derived_other->field() && + mp_exp() == derived_other->mp_exp() && p() == derived_other->p(); + } +}; + class TypeContext final { public: using TypeCreateFn = @@ -395,7 +453,8 @@ class TypeContext final { public: TypeContext() { - addTypes(); // Base types that we need to register + addTypes(); // Base types that we need to register } template diff --git a/libspu/core/type_test.cc b/libspu/core/type_test.cc index 88107473..29510b52 100644 --- a/libspu/core/type_test.cc +++ b/libspu/core/type_test.cc @@ -125,4 +125,24 @@ TEST(TypeTest, RingTy) { EXPECT_EQ(Type::fromString(fm128.toString()), fm128); } +TEST(TypeTest, GfmpTy) { + Type gfmp31 = makeType(FM32); + EXPECT_EQ(gfmp31.size(), 4); + EXPECT_TRUE(gfmp31.isa()); + EXPECT_EQ(gfmp31.toString(), "Gfmp"); + EXPECT_EQ(Type::fromString(gfmp31.toString()), gfmp31); + + Type gfmp61 = makeType(FM64); + EXPECT_EQ(gfmp61.size(), 8); + EXPECT_TRUE(gfmp61.isa()); + EXPECT_EQ(gfmp61.toString(), "Gfmp"); + EXPECT_EQ(Type::fromString(gfmp61.toString()), gfmp61); + + Type gfmp127 = makeType(FM128); + EXPECT_EQ(gfmp127.size(), 16); + EXPECT_TRUE(gfmp127.isa()); + EXPECT_EQ(gfmp127.toString(), "Gfmp"); + EXPECT_EQ(Type::fromString(gfmp127.toString()), gfmp127); +} + } // namespace spu diff --git a/libspu/core/type_util.cc b/libspu/core/type_util.cc index 8261e03f..7eac207e 100644 --- a/libspu/core/type_util.cc +++ b/libspu/core/type_util.cc @@ -122,6 +122,22 @@ std::ostream& operator<<(std::ostream& os, ProtocolKind protocol) { return os; } +////////////////////////////////////////////////////////////// +// Field GFP mappings, currently only support Mersenne primes +////////////////////////////////////////////////////////////// +size_t GetMersennePrimeExp(FieldType field) { +#define CASE(Name, ScalarT, MersennePrimeExp) \ + case FieldType::Name: \ + return MersennePrimeExp; \ + break; + switch (field) { + FIELD_TO_MERSENNE_PRIME_EXP_MAP(CASE) + default: + SPU_THROW("unknown supported field {}", field); + } +#undef CASE +} + ////////////////////////////////////////////////////////////// // Field 2k types, TODO(jint) support Zq ////////////////////////////////////////////////////////////// diff --git a/libspu/core/type_util.h 
b/libspu/core/type_util.h index 4e70ea7e..84b04f45 100644 --- a/libspu/core/type_util.h +++ b/libspu/core/type_util.h @@ -212,7 +212,17 @@ FOREACH_PT_TYPES(CASE) std::ostream& operator<<(std::ostream& os, ProtocolKind protocol); ////////////////////////////////////////////////////////////// -// Field 2k types, TODO(jint) support Zq +// Field GFP mappings, currently only support Mersenne primes +////////////////////////////////////////////////////////////// +#define FIELD_TO_MERSENNE_PRIME_EXP_MAP(FN) \ + FN(FM32, uint32_t, 31) \ + FN(FM64, uint64_t, 61) \ + FN(FM128, uint128_t, 127) + +size_t GetMersennePrimeExp(FieldType field); + +////////////////////////////////////////////////////////////// +// Field 2k types ////////////////////////////////////////////////////////////// #define FIELD_TO_STORAGE_MAP(FN) \ FN(FM32, PT_U32) \ @@ -259,6 +269,21 @@ inline size_t SizeOf(FieldType field) { return SizeOf(GetStorageType(field)); } } \ }() +////////////////////////////////////////////////////////////// +// Field Prime types +////////////////////////////////////////////////////////////// +template +struct ScalarTypeToPrime {}; + +#define DEF_TRAITS(Field, ScalarT, Exp) \ + template <> \ + struct ScalarTypeToPrime { \ + static constexpr size_t exp = Exp; \ + static constexpr ScalarT prime = (static_cast(1) << Exp) - 1; \ + }; +FIELD_TO_MERSENNE_PRIME_EXP_MAP(DEF_TRAITS) +#undef DEF_TRAITS + ////////////////////////////////////////////////////////////// // Value range information, should it be here, at top level(jint)? ////////////////////////////////////////////////////////////// diff --git a/libspu/kernel/hal/BUILD.bazel b/libspu/kernel/hal/BUILD.bazel index ef8a61b9..b5aeef41 100644 --- a/libspu/kernel/hal/BUILD.bazel +++ b/libspu/kernel/hal/BUILD.bazel @@ -127,6 +127,7 @@ spu_cc_test( deps = [ ":fxp_approx", "//libspu/kernel:test_util", + "//libspu/mpc/utils:simulate", ], ) diff --git a/libspu/kernel/hal/fxp_approx.cc b/libspu/kernel/hal/fxp_approx.cc index 9f3105c3..34667e84 100644 --- a/libspu/kernel/hal/fxp_approx.cc +++ b/libspu/kernel/hal/fxp_approx.cc @@ -201,6 +201,31 @@ Value exp_taylor(SPUContext* ctx, const Value& x) { return res; } +Value exp_prime(SPUContext* ctx, const Value& x) { + auto clamped_x = x; + auto offset = ctx->config().experimental_exp_prime_offset(); + auto fxp = ctx->getFxpBits(); + if (!ctx->config().experimental_exp_prime_disable_lower_bound()) { + // currently the bound is tied to FM128 + SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128); + auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E; + clamped_x = _clamp_lower(ctx, clamped_x, + constant(ctx, lower_bound, x.dtype(), x.shape())) + .setDtype(x.dtype()); + } + if (ctx->config().experimental_exp_prime_enable_upper_bound()) { + // currently the bound is tied to FM128 + SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128); + auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E; + clamped_x = _clamp_upper(ctx, clamped_x, + constant(ctx, upper_bound, x.dtype(), x.shape())) + .setDtype(x.dtype()); + } + + auto ret = dynDispatch(ctx, "exp_a", clamped_x); + return ret.setDtype(x.dtype()); +} + namespace { // Pade approximation of exp2(x), x is in [0, 1]. @@ -439,13 +464,22 @@ Value f_exp(SPUContext* ctx, const Value& x) { case RuntimeConfig::EXP_PADE: { // The valid input for exp_pade is [-kInputLimit, kInputLimit]. // TODO(junfeng): should merge clamp into exp_pade to save msb ops. 
- const float kInputLimit = 32 / std::log2(std::exp(1)); + const float kInputLimit = 32.0 / std::log2(std::exp(1)); const auto clamped_x = _clamp(ctx, x, constant(ctx, -kInputLimit, x.dtype(), x.shape()), constant(ctx, kInputLimit, x.dtype(), x.shape())) .setDtype(x.dtype()); return detail::exp_pade(ctx, clamped_x); } + case RuntimeConfig::EXP_PRIME: + if (ctx->hasKernel("exp_a")) { + return detail::exp_prime(ctx, x); + } else { + SPU_THROW( + "exp_a is not implemented for this protocol, currently only " + "2pc " + "semi2k is supported."); + } default: SPU_THROW("unexpected exp approximation method {}", ctx->config().fxp_exp_mode()); diff --git a/libspu/kernel/hal/fxp_approx.h b/libspu/kernel/hal/fxp_approx.h index fa401887..44c724f3 100644 --- a/libspu/kernel/hal/fxp_approx.h +++ b/libspu/kernel/hal/fxp_approx.h @@ -38,6 +38,8 @@ Value exp2_pade(SPUContext* ctx, const Value& x); // Works for range [-12.0, 18.0] Value exp_pade(SPUContext* ctx, const Value& x); +Value exp_prime(SPUContext* ctx, const Value& x); + Value tanh_chebyshev(SPUContext* ctx, const Value& x); } // namespace detail diff --git a/libspu/kernel/hal/fxp_approx_test.cc b/libspu/kernel/hal/fxp_approx_test.cc index c79dc434..d540eb2b 100644 --- a/libspu/kernel/hal/fxp_approx_test.cc +++ b/libspu/kernel/hal/fxp_approx_test.cc @@ -20,6 +20,7 @@ #include "libspu/kernel/hal/constants.h" #include "libspu/kernel/hal/type_cast.h" #include "libspu/kernel/test_util.h" +#include "libspu/mpc/utils/simulate.h" namespace spu::kernel::hal { @@ -78,10 +79,35 @@ TEST(FxpTest, ExponentialPade) { << y; } +TEST(FxpTest, ExponentialPrime) { + std::cout << "test exp_prime" << std::endl; + spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { + RuntimeConfig conf; + conf.set_protocol(ProtocolKind::SEMI2K); + conf.set_field(FieldType::FM128); + conf.set_fxp_fraction_bits(40); + conf.set_experimental_enable_exp_prime(true); + SPUContext ctx = test::makeSPUContext(conf, lctx); + + auto offset = ctx.config().experimental_exp_prime_offset(); + auto fxp = ctx.getFxpBits(); + auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E; + auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E; + + xt::xarray x = xt::linspace(lower_bound, upper_bound, 4000); + + Value a = test::makeValue(&ctx, x, VIS_SECRET); + Value c = detail::exp_prime(&ctx, a); + auto y = dump_public_as(&ctx, reveal(&ctx, c)); + EXPECT_TRUE(xt::allclose(xt::exp(x), y, 0.01, 0.001)) + << xt::exp(x) << std::endl + << y; + }); +} + TEST(FxpTest, Log) { // GIVEN SPUContext ctx = test::makeSPUContext(); - xt::xarray x = {{0.05, 0.5}, {5, 50}}; // public log { diff --git a/libspu/kernel/hal/ring.cc b/libspu/kernel/hal/ring.cc index 6844a1b8..725fd498 100644 --- a/libspu/kernel/hal/ring.cc +++ b/libspu/kernel/hal/ring.cc @@ -472,14 +472,25 @@ Value _mux(SPUContext* ctx, const Value& pred, const Value& a, const Value& b) { Value _clamp(SPUContext* ctx, const Value& x, const Value& minv, const Value& maxv) { SPU_TRACE_HAL_LEAF(ctx, x, minv, maxv); - // clamp lower bound, res = x < minv ? minv : x auto res = _mux(ctx, _less(ctx, x, minv), minv, x); - // clamp upper bound, res = res < maxv ? res, maxv return _mux(ctx, _less(ctx, res, maxv), res, maxv); } +// TODO: refactor polymorphic, and may use select functions in polymorphic +Value _clamp_lower(SPUContext* ctx, const Value& x, const Value& minv) { + SPU_TRACE_HAL_LEAF(ctx, x, minv); + // clamp lower bound, res = x < minv ? 
minv : x + return _mux(ctx, _less(ctx, x, minv), minv, x); +} + +Value _clamp_upper(SPUContext* ctx, const Value& x, const Value& maxv) { + SPU_TRACE_HAL_LEAF(ctx, x, maxv); + // clamp upper bound, x = x < maxv ? x, maxv + return _mux(ctx, _less(ctx, x, maxv), x, maxv); +} + Value _constant(SPUContext* ctx, uint128_t init, const Shape& shape) { return _make_p(ctx, init, shape); } diff --git a/libspu/kernel/hal/ring.h b/libspu/kernel/hal/ring.h index 0dd7234a..f0bbb01b 100644 --- a/libspu/kernel/hal/ring.h +++ b/libspu/kernel/hal/ring.h @@ -88,6 +88,11 @@ Value _mux(SPUContext* ctx, const Value& pred, const Value& a, const Value& b); // TODO: test me Value _clamp(SPUContext* ctx, const Value& x, const Value& minv, const Value& maxv); + +Value _clamp_lower(SPUContext* ctx, const Value& x, const Value& minv); + +Value _clamp_upper(SPUContext* ctx, const Value& x, const Value& maxv); + // Make a public value from uint128_t init value. // // If the current working field has less than 128bit, the lower sizeof(field) diff --git a/libspu/mpc/aby3/oram.cc b/libspu/mpc/aby3/oram.cc index 3eae263b..7eeb2089 100644 --- a/libspu/mpc/aby3/oram.cc +++ b/libspu/mpc/aby3/oram.cc @@ -440,14 +440,14 @@ void OramContext::onehotB2A(KernelEvalContext *ctx, DpfGenCtrl ctrl) { const std::vector v = convert_help_v[dpf_idx]; std::for_each(e.begin(), e.end(), [&](T ele) { pm += ele; }); std::for_each(v.begin(), v.end(), [&](T ele) { F -= ele; }); - auto blinded_pm = pm + r[0]; + T blinded_pm = pm + r[0]; // open blinded_pm comm->sendAsync(dst_rank, {blinded_pm}, "open(blinded_pm)"); blinded_pm += comm->recv(dst_rank, "open(blinded_pm)")[0]; auto pm_mul_F = mul2pc(ctx, {pm}, {F}, static_cast(ctrl)); - auto blinded_F = pm_mul_F[0] + r[0]; + T blinded_F = pm_mul_F[0] + r[0]; // open blinded_F comm->sendAsync(dst_rank, {blinded_F}, "open(blinded_F)"); diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc index 26f9fbd1..5abf60dc 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc @@ -57,7 +57,11 @@ TEST_P(CompareProtTest, Compare) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) { + xinp[2] = 100; + } else { + xinp[2] = 1000; + } }); NdArrayRef cmp_oup[2]; @@ -108,7 +112,11 @@ TEST_P(CompareProtTest, CompareBitWidth) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) { + xinp[2] = 100; + } else { + xinp[2] = 1000; + } pforeach(0, inp[0].numel(), [&](int64_t i) { xinp[i] &= mask; }); }); @@ -178,7 +186,11 @@ TEST_P(CompareProtTest, WithEq) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) { + xinp[2] = 100; + } else { + xinp[2] = 1000; + } }); NdArrayRef cmp_oup[2]; @@ -237,7 +249,11 @@ TEST_P(CompareProtTest, WithEqBitWidth) { xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; - xinp[2] = 1000; + if constexpr (std::is_same_v) { + xinp[2] = 100; + } else { + xinp[2] = 1000; + } pforeach(0, inp[0].numel(), [&](int64_t i) { xinp[i] &= mask; }); }); diff --git a/libspu/mpc/common/BUILD.bazel b/libspu/mpc/common/BUILD.bazel index 375db84d..f22bc908 100644 --- a/libspu/mpc/common/BUILD.bazel +++ b/libspu/mpc/common/BUILD.bazel @@ -42,6 +42,7 @@ spu_cc_library( hdrs = ["communicator.h"], deps = [ "//libspu/core:object", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", "@yacl//yacl/link:context", 
"@yacl//yacl/link/algorithm:allgather", @@ -88,6 +89,7 @@ spu_cc_library( hdrs = ["prg_tensor.h"], deps = [ "//libspu/core:ndarray_ref", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", "@yacl//yacl/crypto/tools:prg", ], diff --git a/libspu/mpc/common/communicator.cc b/libspu/mpc/common/communicator.cc index 3e41dfe8..b7dc4089 100644 --- a/libspu/mpc/common/communicator.cc +++ b/libspu/mpc/common/communicator.cc @@ -14,6 +14,7 @@ #include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc { @@ -53,7 +54,11 @@ NdArrayRef Communicator::allReduce(ReduceOp op, const NdArrayRef& in, auto arr = NdArrayRef(stealBuffer(std::move(bufs[idx])), in.eltype(), in.shape(), makeCompactStrides(in.shape()), kOffset); if (op == ReduceOp::ADD) { - ring_add_(res, arr); + if (in.eltype().isa()) { + gfmp_add_mod_(res, arr); + } else { + ring_add_(res, arr); + } } else if (op == ReduceOp::XOR) { ring_xor_(res, arr); } else { @@ -86,7 +91,11 @@ NdArrayRef Communicator::reduce(ReduceOp op, const NdArrayRef& in, size_t root, NdArrayRef(stealBuffer(std::move(bufs[idx])), in.eltype(), in.shape(), makeCompactStrides(in.shape()), kOffset); if (op == ReduceOp::ADD) { - ring_add_(res, arr); + if (in.eltype().isa()) { + gfmp_add_mod_(res, arr); + } else { + ring_add_(res, arr); + } } else if (op == ReduceOp::XOR) { ring_xor_(res, arr); } else { @@ -94,7 +103,6 @@ NdArrayRef Communicator::reduce(ReduceOp op, const NdArrayRef& in, size_t root, } } } - stats_.latency += 1; stats_.comm += in.numel() * in.elsize(); @@ -116,6 +124,42 @@ NdArrayRef Communicator::rotate(const NdArrayRef& in, std::string_view tag) { makeCompactStrides(in.shape()), kOffset); } +std::vector Communicator::gather(const NdArrayRef& in, size_t root, + std::string_view tag) { + const auto array = getOrCreateCompactArray(in); + yacl::ByteContainerView bv(reinterpret_cast(array.data()), + array.numel() * array.elsize()); + auto bufs = yacl::link::Gather(lctx_, bv, root, tag); + + stats_.latency += 1; + stats_.comm += array.numel() * array.elsize(); + + auto res = std::vector(getWorldSize()); + if (root == getRank()) { + SPU_ENFORCE_EQ(bufs.size(), getWorldSize()); + for (size_t idx = 0; idx < bufs.size(); idx++) { + res[idx] = + NdArrayRef(stealBuffer(std::move(bufs[idx])), in.eltype(), in.shape(), + makeCompactStrides(in.shape()), kOffset); + } + } + return res; +} + +NdArrayRef Communicator::broadcast(const NdArrayRef& in, size_t root, + std::string_view tag) { + const auto array = getOrCreateCompactArray(in); + yacl::ByteContainerView bv(reinterpret_cast(array.data()), + array.elsize() * array.numel()); + auto buf = yacl::link::Broadcast(lctx_, bv, root, tag); + + stats_.latency += 1; + stats_.comm += in.elsize() * in.numel(); + + return NdArrayRef(stealBuffer(std::move(buf)), in.eltype(), in.shape(), + makeCompactStrides(in.shape()), kOffset); +} + void Communicator::sendAsync(size_t dst_rank, const NdArrayRef& in, std::string_view tag) { const auto array = getOrCreateCompactArray(in); @@ -132,4 +176,4 @@ NdArrayRef Communicator::recv(size_t src_rank, const Type& eltype, return NdArrayRef(stealBuffer(std::move(buf)), eltype, {numel}, {1}, kOffset); } -} // namespace spu::mpc +} // namespace spu::mpc \ No newline at end of file diff --git a/libspu/mpc/common/communicator.h b/libspu/mpc/common/communicator.h index 42a84097..f6103937 100644 --- a/libspu/mpc/common/communicator.h +++ b/libspu/mpc/common/communicator.h @@ -103,6 +103,11 @@ class Communicator : 
public State { NdArrayRef allReduce(ReduceOp op, const NdArrayRef& in, std::string_view tag); + std::vector gather(const NdArrayRef& in, size_t root, + std::string_view tag); + + NdArrayRef broadcast(const NdArrayRef& in, size_t root, std::string_view tag); + NdArrayRef reduce(ReduceOp op, const NdArrayRef& in, size_t root, std::string_view tag); diff --git a/libspu/mpc/common/prg_tensor.h b/libspu/mpc/common/prg_tensor.h index 54de9171..d3b3e704 100644 --- a/libspu/mpc/common/prg_tensor.h +++ b/libspu/mpc/common/prg_tensor.h @@ -15,6 +15,7 @@ #pragma once #include "libspu/core/ndarray_ref.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc { @@ -22,28 +23,46 @@ namespace spu::mpc { using PrgSeed = uint128_t; using PrgCounter = uint64_t; +// Gfmp is regarded as word +// standing for Galois Field with Mersenne Prime. +enum class ElementType { kRing, kGfmp }; + struct PrgArrayDesc { Shape shape; FieldType field; PrgCounter prg_counter; + ElementType eltype; }; inline NdArrayRef prgCreateArray(FieldType field, const Shape& shape, PrgSeed seed, PrgCounter* counter, - PrgArrayDesc* desc) { + PrgArrayDesc* desc, + ElementType eltype = ElementType::kRing) { if (desc != nullptr) { - *desc = {Shape(shape.begin(), shape.end()), field, *counter}; + *desc = {Shape(shape.begin(), shape.end()), field, *counter, eltype}; + } + if (eltype == ElementType::kGfmp) { + return gfmp_rand(field, shape, seed, counter); + } else { + return ring_rand(field, shape, seed, counter); } - return ring_rand(field, shape, seed, counter); } inline NdArrayRef prgReplayArray(PrgSeed seed, const PrgArrayDesc& desc) { PrgCounter counter = desc.prg_counter; - return ring_rand(desc.field, desc.shape, seed, &counter); + if (desc.eltype == ElementType::kGfmp) { + return gfmp_rand(desc.field, desc.shape, seed, &counter); + } else { + return ring_rand(desc.field, desc.shape, seed, &counter); + } } inline NdArrayRef prgReplayArrayMutable(PrgSeed seed, PrgArrayDesc& desc) { - return ring_rand(desc.field, desc.shape, seed, &desc.prg_counter); + if (desc.eltype == ElementType::kGfmp) { + return gfmp_rand(desc.field, desc.shape, seed, &desc.prg_counter); + } else { + return ring_rand(desc.field, desc.shape, seed, &desc.prg_counter); + } } } // namespace spu::mpc diff --git a/libspu/mpc/semi2k/BUILD.bazel b/libspu/mpc/semi2k/BUILD.bazel index e845ee73..dfef1ec7 100644 --- a/libspu/mpc/semi2k/BUILD.bazel +++ b/libspu/mpc/semi2k/BUILD.bazel @@ -46,6 +46,34 @@ spu_cc_library( ], ) +spu_cc_library( + name = "prime_utils", + srcs = ["prime_utils.cc"], + hdrs = ["prime_utils.h"], + deps = [ + ":state", + ":type", + "//libspu/mpc:kernel", + "//libspu/mpc/common:communicator", + "//libspu/mpc/utils:gfmp", + "//libspu/mpc/utils:ring_ops", + ], +) + +spu_cc_library( + name = "exp", + srcs = ["exp.cc"], + hdrs = ["exp.h"], + deps = [ + ":prime_utils", + ":state", + ":type", + "//libspu/mpc:kernel", + "//libspu/mpc/utils:gfmp", + "//libspu/mpc/utils:ring_ops", + ], +) + spu_cc_library( name = "conversion", srcs = ["conversion.cc"], @@ -83,6 +111,7 @@ spu_cc_library( ":arithmetic", ":boolean", ":conversion", + ":exp", ":permute", ":state", "//libspu/mpc/common:prg_state", @@ -94,7 +123,10 @@ spu_cc_test( name = "protocol_test", srcs = ["protocol_test.cc"], deps = [ + ":exp", + ":prime_utils", ":protocol", + ":type", "//libspu/mpc:ab_api_test", "//libspu/mpc:api_test", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_server", diff --git a/libspu/mpc/semi2k/beaver/beaver_cache.h 
b/libspu/mpc/semi2k/beaver/beaver_cache.h index 3c9d9ade..5de345ff 100644 --- a/libspu/mpc/semi2k/beaver/beaver_cache.h +++ b/libspu/mpc/semi2k/beaver/beaver_cache.h @@ -32,9 +32,11 @@ namespace spu::mpc::semi2k { class BeaverCache { public: + // clang-format off BeaverCache() : cache_db_(fmt::format("BeaverCache.{}.{}.{}", getpid(), fmt::ptr(this), std::random_device()())) {}; + // clang-format on ~BeaverCache() { db_.reset(); try { diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel b/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel index 05e589f3..5f0bd1e1 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel +++ b/libspu/mpc/semi2k/beaver/beaver_impl/BUILD.bazel @@ -25,6 +25,7 @@ spu_cc_library( "//libspu/mpc/semi2k/beaver:beaver_interface", "//libspu/mpc/semi2k/beaver/beaver_impl/trusted_party", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_stream", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", "@com_github_microsoft_seal//:seal", "@yacl//yacl/link", @@ -40,6 +41,7 @@ spu_cc_test( ":beaver_ttp", "//libspu/core:xt_helper", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_server", + "//libspu/mpc/utils:gfmp", "//libspu/mpc/utils:permute", "//libspu/mpc/utils:simulate", "@com_google_googletest//:gtest", @@ -55,6 +57,7 @@ spu_cc_library( "//libspu/mpc/semi2k/beaver:beaver_interface", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:beaver_stream", "//libspu/mpc/semi2k/beaver/beaver_impl/ttp_server:service_cc_proto", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:ring_ops", "@yacl//yacl/crypto/pke:sm2_enc", "@yacl//yacl/link", diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc index 78cc71d2..300a2f6e 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_test.cc @@ -24,6 +24,7 @@ #include "libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.h" +#include "libspu/mpc/utils/gfmp.h" #include "libspu/mpc/utils/permute.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -158,6 +159,60 @@ std::vector open_buffer(std::vector& in_buffers, } return ret; } + +template +std::vector open_buffer_gfmp(std::vector& in_buffers, + FieldType k_field, + const std::vector& shapes, + size_t k_world_size, bool add_open) { + std::vector ret; + + auto reduce = [&](NdArrayRef& r, yacl::Buffer& b) { + if (b.size() == 0) { + return; + } + EXPECT_EQ(b.size(), r.shape().numel() * SizeOf(k_field)); + NdArrayRef a(std::make_shared(std::move(b)), ret[0].eltype(), + r.shape()); + auto Ta = r.eltype(); + gfmp_add_mod_(r, a.as(Ta)); + }; + if constexpr (std::is_same_v) { + ret.resize(3); + SPU_ENFORCE(shapes.size() == 3); + for (size_t i = 0; i < shapes.size(); i++) { + ret[i] = gfmp_zeros(k_field, shapes[i]); + } + for (Rank r = 0; r < k_world_size; r++) { + auto& [a_buf, b_buf, c_buf] = in_buffers[r]; + reduce(ret[0], a_buf); + reduce(ret[1], b_buf); + reduce(ret[2], c_buf); + } + } else if constexpr (std::is_same_v) { + ret.resize(2); + SPU_ENFORCE(shapes.size() == 2); + for (size_t i = 0; i < shapes.size(); i++) { + ret[i] = gfmp_zeros(k_field, shapes[i]); + } + for (Rank r = 0; r < k_world_size; r++) { + auto& [a_buf, b_buf] = in_buffers[r]; + reduce(ret[0], a_buf); + reduce(ret[1], b_buf); + } + } else if constexpr (std::is_same_v) { + ret.resize(1); + 
SPU_ENFORCE(shapes.size() == 1); + for (size_t i = 0; i < shapes.size(); i++) { + ret[i] = gfmp_zeros(k_field, shapes[i]); + } + for (Rank r = 0; r < k_world_size; r++) { + auto& a_buf = in_buffers[r]; + reduce(ret[0], a_buf); + } + } + return ret; +} } // namespace TEST_P(BeaverTest, Mul_large) { @@ -215,11 +270,11 @@ TEST_P(BeaverTest, Mul_large) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -242,10 +297,10 @@ TEST_P(BeaverTest, Mul_large) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -269,12 +324,12 @@ TEST_P(BeaverTest, Mul_large) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -299,14 +354,14 @@ TEST_P(BeaverTest, Mul_large) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { // mul not support transpose. // enforce ne - EXPECT_NE(_cache_a[idx], _a[idx]); - EXPECT_NE(_cache_b[idx], _b[idx]); + EXPECT_NE(_a_cache[idx], _a[idx]); + EXPECT_NE(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -370,11 +425,11 @@ TEST_P(BeaverTest, Mul) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -397,10 +452,10 @@ TEST_P(BeaverTest, Mul) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -424,12 +479,12 @@ TEST_P(BeaverTest, Mul) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -454,14 +509,14 @@ TEST_P(BeaverTest, Mul) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); NdArrayView _b(open[1]); - NdArrayView _cache_a(x_cache); - NdArrayView _cache_b(y_cache); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { // mul not support transpose. // enforce ne - EXPECT_NE(_cache_a[idx], _a[idx]); - EXPECT_NE(_cache_b[idx], _b[idx]); + EXPECT_NE(_a_cache[idx], _a[idx]); + EXPECT_NE(_b_cache[idx], _b[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); @@ -470,6 +525,176 @@ TEST_P(BeaverTest, Mul) { } } +TEST_P(BeaverTest, MulGfmp) { + const auto factory = std::get<0>(GetParam()).first; + const size_t kWorldSize = std::get<1>(GetParam()); + const FieldType kField = std::get<2>(GetParam()); + const int64_t kMaxDiff = std::get<3>(GetParam()); + const size_t adjust_rank = std::get<4>(GetParam()); + const int64_t kNumel = 7; + + std::vector triples(kWorldSize); + + std::vector x_desc(kWorldSize); + std::vector y_desc(kWorldSize); + NdArrayRef x_cache; + NdArrayRef y_cache; + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + &y_desc[lctx->Rank()], ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + auto prime = ScalarTypeToPrime::prime; + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + + x_cache = open[0]; + y_cache = open[1]; + } + { + utils::simulate(kWorldSize, + [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + x_desc[lctx->Rank()].status = Beaver::Replay; + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + nullptr, ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _a_cache(x_cache); + NdArrayView _b(open[1]); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + auto prime = ScalarTypeToPrime::prime; + EXPECT_EQ(_a_cache[idx], _a[idx]); + auto t = mul_mod(_a[idx], _b[idx]) % prime; + auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + y_desc[lctx->Rank()].status = Beaver::Replay; + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, nullptr, &y_desc[lctx->Rank()], + ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _b_cache(y_cache); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + EXPECT_EQ(_b_cache[idx], _b[idx]); + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + auto prime = ScalarTypeToPrime::prime; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + x_desc[lctx->Rank()].status = Beaver::Replay; + y_desc[lctx->Rank()].status = Beaver::Replay; + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + &y_desc[lctx->Rank()], ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + auto prime = ScalarTypeToPrime::prime; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } + { + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx, ttp_options_, adjust_rank); + x_desc[lctx->Rank()].status = Beaver::TransposeReplay; + y_desc[lctx->Rank()].status = Beaver::TransposeReplay; + // mul not support transpose. + triples[lctx->Rank()] = + beaver->Mul(kField, kNumel, &x_desc[lctx->Rank()], + &y_desc[lctx->Rank()], ElementType::kGfmp); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + auto open = open_buffer_gfmp( + triples, kField, std::vector(3, {kNumel}), kWorldSize, true); + + DISPATCH_ALL_FIELDS(kField, [&]() { + NdArrayView _a(open[0]); + NdArrayView _b(open[1]); + NdArrayView _a_cache(x_cache); + NdArrayView _b_cache(y_cache); + NdArrayView _c(open[2]); + for (auto idx = 0; idx < _a.numel(); idx++) { + // mul not support transpose. + // enforce ne + EXPECT_NE(_a_cache[idx], _a[idx]); + EXPECT_NE(_b_cache[idx], _b[idx]); + auto t = mul_mod(_a[idx], _b[idx]); + auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; + auto prime = ScalarTypeToPrime::prime; + auto error_mod_p = static_cast(err) % prime; + EXPECT_LE(error_mod_p, kMaxDiff); + } + }); + } +} + TEST_P(BeaverTest, And) { const auto factory = std::get<0>(GetParam()).first; const size_t kWorldSize = std::get<1>(GetParam()); @@ -566,11 +791,11 @@ TEST_P(BeaverTest, Dot) { auto res = ring_mmul(x_cache, open[1]); DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -593,11 +818,11 @@ TEST_P(BeaverTest, Dot) { auto res = ring_mmul(open[0], y_cache); DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -621,14 +846,14 @@ TEST_P(BeaverTest, Dot) { auto res = ring_mmul(x_cache, y_cache); DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); - NdArrayView _cache_b(y_cache); + NdArrayView _b_cache(y_cache); NdArrayView _r(res); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -653,16 +878,16 @@ TEST_P(BeaverTest, Dot) { DISPATCH_ALL_FIELDS(kField, [&]() { auto transpose_a = open[0].transpose(); NdArrayView _a(transpose_a); - NdArrayView _cache_a(y_cache); + NdArrayView _a_cache(y_cache); auto transpose_b = open[1].transpose(); NdArrayView _b(transpose_b); - NdArrayView _cache_b(x_cache); + NdArrayView _b_cache(x_cache); auto transpose_r = res.transpose(); NdArrayView _r(transpose_r); NdArrayView _c(open[2]); for (auto idx = 0; idx < res.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); - EXPECT_EQ(_cache_b[idx], _b[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); + EXPECT_EQ(_b_cache[idx], _b[idx]); auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); } @@ -685,11 +910,11 @@ TEST_P(BeaverTest, Dot) { DISPATCH_ALL_FIELDS(kField, [&]() { NdArrayView _a(open[0]); - NdArrayView _cache_a(x_cache); + NdArrayView _a_cache(x_cache); NdArrayView _b(open[1]); NdArrayView _c(open[2]); for (auto idx = 0; idx < _a.numel(); idx++) { - EXPECT_EQ(_cache_a[idx], _a[idx]); + EXPECT_EQ(_a_cache[idx], _a[idx]); auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; EXPECT_LE(err, kMaxDiff); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc index 402d1c0d..f876209d 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc @@ -23,6 +23,7 @@ #include "libspu/mpc/common/prg_tensor.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::semi2k { @@ -32,9 +33,9 @@ namespace { inline size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; } void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, - const std::vector& encrypted_seeds, - PrgCounter counter, PrgSeed self_seed) { + PrgCounter counter, PrgSeed self_seed, + ElementType eltype = ElementType::kRing) { if (desc == nullptr || desc->status != Beaver::Init) { return; } @@ -43,6 +44,7 @@ void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, desc->prg_counter = counter; desc->encrypted_seeds = encrypted_seeds; desc->seed = self_seed; + desc->eltype = eltype; } } // namespace @@ -67,7 +69,8 @@ BeaverTfpUnsafe::BeaverTfpUnsafe(std::shared_ptr lctx) BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, ReplayDesc* x_desc, - ReplayDesc* y_desc) { + ReplayDesc* y_desc, + ElementType eltype) { std::vector ops(3); Shape shape({size, 1}); std::vector> replay_seeds(3); @@ -75,9 +78,13 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, auto if_replay = [&](const ReplayDesc* replay_desc, size_t idx) { if (replay_desc == nullptr || replay_desc->status != Beaver::Replay) { ops[idx].seeds = seeds_; - return prgCreateArray(field, shape, seed_, &counter_, &ops[idx].desc); + // enforce the eltypes in ops + ops[idx].desc.eltype = eltype; + return prgCreateArray(field, shape, seed_, &counter_, &ops[idx].desc, + eltype); } else { SPU_ENFORCE(replay_desc->field == field); + SPU_ENFORCE(replay_desc->eltype == eltype); SPU_ENFORCE(replay_desc->size == size); if (lctx_->Rank() == 0) { SPU_ENFORCE(replay_desc->encrypted_seeds.size() == lctx_->WorldSize()); @@ -90,25 +97,31 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, } ops[idx].seeds = replay_seeds[idx]; ops[idx].desc.field = field; + ops[idx].desc.eltype = eltype; ops[idx].desc.shape = shape; ops[idx].desc.prg_counter = replay_desc->prg_counter; } PrgCounter tmp_counter = replay_desc->prg_counter; return prgCreateArray(field, shape, replay_desc->seed, &tmp_counter, - nullptr); + nullptr, eltype); } }; - FillReplayDesc(x_desc, field, size, seeds_buff_, counter_, seed_); + FillReplayDesc(x_desc, field, size, seeds_buff_, counter_, seed_, eltype); auto a = if_replay(x_desc, 0); - FillReplayDesc(y_desc, field, size, seeds_buff_, counter_, seed_); + FillReplayDesc(y_desc, field, size, seeds_buff_, counter_, seed_, eltype); auto b = if_replay(y_desc, 1); - auto c = prgCreateArray(field, shape, seed_, &counter_, &ops[2].desc); + auto c = prgCreateArray(field, shape, seed_, &counter_, &ops[2].desc, eltype); if (lctx_->Rank() == 0) { ops[2].seeds = seeds_; auto adjust = TrustedParty::adjustMul(absl::MakeSpan(ops)); - ring_add_(c, adjust); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjust.as(T)); + } else { + ring_add_(c, adjust); + } } Triple ret; @@ -119,6 +132,37 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size, 
return ret; } +BeaverTfpUnsafe::Pair BeaverTfpUnsafe::MulPriv(FieldType field, int64_t size, + ElementType eltype) { + std::vector ops(2); + Shape shape({size, 1}); + + ops[0].seeds = seeds_; + // enforce the eltypes in ops + ops[0].desc.eltype = eltype; + ops[1].desc.eltype = eltype; + auto a_or_b = + prgCreateArray(field, shape, seed_, &counter_, &ops[0].desc, eltype); + auto c = prgCreateArray(field, shape, seed_, &counter_, &ops[1].desc, eltype); + + if (lctx_->Rank() == 0) { + ops[1].seeds = seeds_; + auto adjust = TrustedParty::adjustMulPriv(absl::MakeSpan(ops)); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjust.as(T)); + } else { + ring_add_(c, adjust); + } + } + + Pair ret; + std::get<0>(ret) = std::move(*a_or_b.buf()); + std::get<1>(ret) = std::move(*c.buf()); + + return ret; +} + BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Square(FieldType field, int64_t size, ReplayDesc* x_desc) { std::vector ops(2); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h index 9ca11bca..2f26a716 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.h @@ -45,7 +45,11 @@ class BeaverTfpUnsafe final : public Beaver { explicit BeaverTfpUnsafe(std::shared_ptr lctx); Triple Mul(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr, - ReplayDesc* y_desc = nullptr) override; + ReplayDesc* y_desc = nullptr, + ElementType eltype = ElementType::kRing) override; + + Pair MulPriv(FieldType field, int64_t size, + ElementType eltype = ElementType::kRing) override; Pair Square(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr) override; diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc index c29fa32b..be2e9e86 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc @@ -24,6 +24,7 @@ #include "libspu/mpc/common/prg_tensor.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_stream.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/ring_ops.h" namespace brpc { @@ -41,7 +42,8 @@ inline size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; } void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, const std::vector& encrypted_seeds, - PrgCounter counter, PrgSeed self_seed) { + PrgCounter counter, PrgSeed self_seed, + ElementType eltype = ElementType::kRing) { if (desc == nullptr || desc->status != Beaver::Init) { return; } @@ -50,6 +52,7 @@ void FillReplayDesc(Beaver::ReplayDesc* desc, FieldType field, int64_t size, desc->prg_counter = counter; desc->encrypted_seeds = encrypted_seeds; desc->seed = self_seed; + desc->eltype = eltype; } template @@ -61,11 +64,15 @@ AdjustRequest BuildAdjustRequest( SPU_ENFORCE(!descs.empty()); uint32_t field_size; + ElementType eltype = ElementType::kRing; + for (size_t i = 0; i < descs.size(); i++) { const auto& desc = descs[i]; auto* input = ret.add_prg_inputs(); input->set_prg_count(desc.prg_counter); field_size = SizeOf(desc.field); + eltype = desc.eltype; + input->set_buffer_len(desc.shape.numel() * SizeOf(desc.field)); absl::Span seeds; @@ -83,6 +90,14 @@ AdjustRequest BuildAdjustRequest( beaver::ttp_server::AdjustAndRequest>) { ret.set_field_size(field_size); } + if constexpr (std::is_same_v || + std::is_same_v) { + if (eltype == ElementType::kGfmp) + ret.set_element_type(beaver::ttp_server::ElType::GFMP); + } + return 
ret; } @@ -223,6 +238,10 @@ std::vector RpcCall( if constexpr (std::is_same_v) { stub.AdjustMul(&cntl, &req, &rsp, nullptr); + } else if constexpr (std::is_same_v< + AdjustRequest, + beaver::ttp_server::AdjustMulPrivRequest>) { + stub.AdjustMulPriv(&cntl, &req, &rsp, nullptr); } else if constexpr (std::is_same_v< AdjustRequest, beaver::ttp_server::AdjustSquareRequest>) { @@ -340,15 +359,18 @@ BeaverTtp::BeaverTtp(std::shared_ptr lctx, Options ops) "BEAVER_TTP:SYNC_ENCRYPTED_SEEDS"); } +// TODO: kGfmp supports more operations BeaverTtp::Triple BeaverTtp::Mul(FieldType field, int64_t size, - ReplayDesc* x_desc, ReplayDesc* y_desc) { + ReplayDesc* x_desc, ReplayDesc* y_desc, + ElementType eltype) { std::vector descs(3); std::vector> descs_seed(3, encrypted_seeds_); Shape shape({size, 1}); auto if_replay = [&](const ReplayDesc* replay_desc, size_t idx) { if (replay_desc == nullptr || replay_desc->status != Beaver::Replay) { - return prgCreateArray(field, shape, seed_, &counter_, &descs[idx]); + return prgCreateArray(field, shape, seed_, &counter_, &descs[idx], + eltype); } else { SPU_ENFORCE(replay_desc->field == field); SPU_ENFORCE(replay_desc->size == size); @@ -356,27 +378,35 @@ BeaverTtp::Triple BeaverTtp::Mul(FieldType field, int64_t size, if (lctx_->Rank() == options_.adjust_rank) { descs_seed[idx] = replay_desc->encrypted_seeds; descs[idx].field = field; + descs[idx].eltype = eltype; descs[idx].shape = shape; descs[idx].prg_counter = replay_desc->prg_counter; } PrgCounter tmp_counter = replay_desc->prg_counter; return prgCreateArray(field, shape, replay_desc->seed, &tmp_counter, - &descs[idx]); + &descs[idx], eltype); } }; - FillReplayDesc(x_desc, field, size, encrypted_seeds_, counter_, seed_); + FillReplayDesc(x_desc, field, size, encrypted_seeds_, counter_, seed_, + eltype); auto a = if_replay(x_desc, 0); - FillReplayDesc(y_desc, field, size, encrypted_seeds_, counter_, seed_); + FillReplayDesc(y_desc, field, size, encrypted_seeds_, counter_, seed_, + eltype); auto b = if_replay(y_desc, 1); - auto c = prgCreateArray(field, shape, seed_, &counter_, &descs[2]); + auto c = prgCreateArray(field, shape, seed_, &counter_, &descs[2], eltype); if (lctx_->Rank() == options_.adjust_rank) { auto req = BuildAdjustRequest( descs, descs_seed); auto adjusts = RpcCall(channel_, req, field); SPU_ENFORCE_EQ(adjusts.size(), 1U); - ring_add_(c, adjusts[0].reshape(shape)); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjusts[0].reshape(shape).as(T)); + } else { + ring_add_(c, adjusts[0].reshape(shape)); + } } Triple ret; @@ -387,6 +417,34 @@ BeaverTtp::Triple BeaverTtp::Mul(FieldType field, int64_t size, return ret; } +BeaverTtp::Pair BeaverTtp::MulPriv(FieldType field, int64_t size, + ElementType eltype) { + std::vector descs(2); + std::vector> descs_seed(2, encrypted_seeds_); + Shape shape({size, 1}); + auto a_or_b = + prgCreateArray(field, shape, seed_, &counter_, &descs[0], eltype); + auto c = prgCreateArray(field, shape, seed_, &counter_, &descs[1], eltype); + if (lctx_->Rank() == options_.adjust_rank) { + auto req = BuildAdjustRequest( + descs, descs_seed); + auto adjusts = RpcCall(channel_, req, field); + SPU_ENFORCE_EQ(adjusts.size(), 1U); + if (eltype == ElementType::kGfmp) { + auto T = c.eltype(); + gfmp_add_mod_(c, adjusts[0].reshape(shape).as(T)); + } else { + ring_add_(c, adjusts[0].reshape(shape)); + } + } + + Pair ret; + std::get<0>(ret) = std::move(*a_or_b.buf()); + std::get<1>(ret) = std::move(*c.buf()); + + return ret; +} + BeaverTtp::Pair 
BeaverTtp::Square(FieldType field, int64_t size, ReplayDesc* x_desc) { std::vector descs(2); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h index ecb39237..501d5eac 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h +++ b/libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.h @@ -66,7 +66,11 @@ class BeaverTtp final : public Beaver { ~BeaverTtp() override = default; Triple Mul(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr, - ReplayDesc* y_desc = nullptr) override; + ReplayDesc* y_desc = nullptr, + ElementType eltype = ElementType::kRing) override; + + Pair MulPriv(FieldType field, int64_t size, + ElementType eltype = ElementType::kRing) override; Pair Square(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr) override; diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel index 5a503213..2613516a 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/BUILD.bazel @@ -21,7 +21,9 @@ spu_cc_library( srcs = ["trusted_party.cc"], hdrs = ["trusted_party.h"], deps = [ + "//libspu/core:type_util", "//libspu/mpc/common:prg_tensor", + "//libspu/mpc/utils:gfmp_ops", "//libspu/mpc/utils:permute", "//libspu/mpc/utils:ring_ops", ], diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc index 1ff405ad..bdeb2811 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.cc @@ -14,22 +14,32 @@ #include "libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h" +#include "libspu/core/type_util.h" +#include "libspu/mpc/common/prg_tensor.h" +#include "libspu/mpc/utils/gfmp_ops.h" #include "libspu/mpc/utils/permute.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::semi2k { namespace { +enum class ReduceOp : uint8_t { + ADD = 0, + XOR = 1, + MUL = 2, +}; + enum class RecOp : uint8_t { ADD = 0, XOR = 1, }; -std::vector reconstruct(RecOp op, - absl::Span ops) { +std::vector reduce(ReduceOp op, + absl::Span ops) { std::vector rs(ops.size()); const auto world_size = ops[0].seeds.size(); + for (size_t rank = 0; rank < world_size; rank++) { for (size_t idx = 0; idx < ops.size(); idx++) { // FIXME: TTP adjuster server and client MUST have same endianness. 
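
In `trusted_party.cc`, the old `reconstruct` (ADD/XOR over the parties' PRG-replayed shares) is generalized into `reduce`, adding a MUL reduction and routing Gfmp operands through the `gfmp_*` modular ops; the hunks that follow use this both for the existing `adjustMul` and the new `adjustMulPriv`. The adjustment identity itself (`adjust_c = ra * rb - rc`, per the comments in `service.proto`) can be sanity-checked with a few lines of standalone arithmetic. The sketch below uses plain `uint64_t` wraparound as a stand-in for the FM64 ring; names and values are illustrative only, not SPU code.

```cpp
#include <cassert>
#include <cstdint>

// Sketch only: adding adjust = (Σ a_i)(Σ b_i) - Σ c_i to rank 0's c-share
// turns the PRG-derived randomness into a valid Beaver triple a * b = c.
int main() {
  uint64_t a0 = 111, a1 = 222;  // parties' PRG-derived shares of a
  uint64_t b0 = 333, b1 = 444;  // shares of b
  uint64_t c0 = 555, c1 = 666;  // shares of c, before adjustment

  uint64_t a = a0 + a1, b = b0 + b1, c = c0 + c1;  // what the TTP reconstructs
  uint64_t adjust = a * b - c;                     // returned to rank 0 only
  c0 += adjust;                                    // ring_add_(c, adjust) on rank 0

  assert((a0 + a1) * (b0 + b1) == c0 + c1);        // now a * b == c (mod 2^64)
  return 0;
}
```

For the Gfmp path the same identity holds modulo the Mersenne prime, with `gfmp_mul_mod`/`gfmp_sub_mod` in place of the ring ops and shares drawn by `gfmp_rand`.
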
@@ -43,12 +53,25 @@ std::vector reconstruct(RecOp op, if (rank == 0) { rs[idx] = t; } else { - if (op == RecOp::ADD) { - ring_add_(rs[idx], t); - } else if (op == RecOp::XOR) { + if (op == ReduceOp::ADD) { + if (ops[idx].desc.eltype == ElementType::kGfmp) { + // TODO: generalize the reduction + gfmp_add_mod_(rs[idx], t); + } else { + ring_add_(rs[idx], t); + } + } else if (op == ReduceOp::XOR) { + // gfmp has no xor inplementation ring_xor_(rs[idx], t); + } else if (op == ReduceOp::MUL) { + if (ops[idx].desc.eltype == ElementType::kGfmp) { + // TODO: generalize the reduction + gfmp_mul_mod_(rs[idx], t); + } else { + ring_mul_(rs[idx], t); + } } else { - SPU_ENFORCE("not supported reconstruct op"); + SPU_ENFORCE("not supported reduction op"); } } } @@ -57,11 +80,17 @@ std::vector reconstruct(RecOp op, return rs; } +std::vector reconstruct(RecOp op, + absl::Span ops) { + return reduce(ReduceOp(op), ops); +} + void checkOperands(absl::Span ops, bool skip_shape = false, bool allow_transpose = false) { for (size_t idx = 1; idx < ops.size(); idx++) { SPU_ENFORCE(skip_shape || ops[0].desc.shape == ops[idx].desc.shape); SPU_ENFORCE(allow_transpose || ops[0].transpose == false); + SPU_ENFORCE(ops[0].desc.eltype == ops[idx].desc.eltype); SPU_ENFORCE(ops[0].desc.field == ops[idx].desc.field); SPU_ENFORCE(ops[0].seeds.size() == ops[idx].seeds.size(), "{} <> {}", ops[0].seeds.size(), ops[idx].seeds.size()); @@ -70,13 +99,41 @@ void checkOperands(absl::Span ops, } // namespace +// TODO: gfmp support more operations NdArrayRef TrustedParty::adjustMul(absl::Span ops) { SPU_ENFORCE_EQ(ops.size(), 3U); checkOperands(ops); auto rs = reconstruct(RecOp::ADD, ops); // adjust = rs[0] * rs[1] - rs[2]; - return ring_sub(ring_mul(rs[0], rs[1]), rs[2]); + if (ops[0].desc.eltype == ElementType::kGfmp) { + return gfmp_sub_mod(gfmp_mul_mod(rs[0], rs[1]), rs[2]); + } else { + return ring_sub(ring_mul(rs[0], rs[1]), rs[2]); + } +} + +// ops are [a_or_b, c] +// P0 generate a, c0 +// P1 generate b, c1 +// The adjustment is ab - (c0 + c1), +// which only needs to be sent to adjust party, e.g. P0. 
+// P0 with adjust is ab - c1 = ab - (c0 + c1) + c0 +// Therefore, +// P0 holds: a, ab - c1 +// P1 holds: b, c1 +NdArrayRef TrustedParty::adjustMulPriv(absl::Span ops) { + SPU_ENFORCE_EQ(ops.size(), 2U); + checkOperands(ops); + + auto ab = reduce(ReduceOp::MUL, ops.subspan(0, 1))[0]; + auto c = reconstruct(RecOp::ADD, ops.subspan(1, 1))[0]; + // adjust = ab - c; + if (ops[0].desc.eltype == ElementType::kGfmp) { + return gfmp_sub_mod(ab, c); + } else { + return ring_sub(ab, c); + } } NdArrayRef TrustedParty::adjustSquare(absl::Span ops) { @@ -84,6 +141,7 @@ NdArrayRef TrustedParty::adjustSquare(absl::Span ops) { auto rs = reconstruct(RecOp::ADD, ops); // adjust = rs[0] * rs[0] - rs[1]; + return ring_sub(ring_mul(rs[0], rs[0]), rs[1]); } diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h index 55a412e9..60098256 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h +++ b/libspu/mpc/semi2k/beaver/beaver_impl/trusted_party/trusted_party.h @@ -33,6 +33,8 @@ class TrustedParty { static NdArrayRef adjustMul(absl::Span); + static NdArrayRef adjustMulPriv(absl::Span); + static NdArrayRef adjustSquare(absl::Span); static NdArrayRef adjustDot(absl::Span); diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc index 30a72277..2a5136ec 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.cc @@ -52,7 +52,8 @@ template std::tuple, std::vector>, size_t> BuildOperand(const AdjustRequest& req, uint32_t field_size, - const std::unique_ptr& decryptor) { + const std::unique_ptr& decryptor, + ElementType eltype) { std::vector ops; std::vector> seeds; size_t pad_length = 0; @@ -140,7 +141,7 @@ BuildOperand(const AdjustRequest& req, uint32_t field_size, } seeds.emplace_back(std::move(seed)); ops.push_back( - TrustedParty::Operand{{shape, type, prg_count}, seeds.back()}); + TrustedParty::Operand{{shape, type, prg_count, eltype}, seeds.back()}); } if constexpr (std::is_same_v) { @@ -305,6 +306,9 @@ std::vector AdjustImpl(const AdjustRequest& req, if constexpr (std::is_same_v) { auto adjust = TrustedParty::adjustMul(ops); ret.push_back(std::move(adjust)); + } else if constexpr (std::is_same_v) { + auto adjust = TrustedParty::adjustMulPriv(ops); + ret.push_back(std::move(adjust)); } else if constexpr (std::is_same_v) { auto adjust = TrustedParty::adjustSquare(ops); ret.push_back(std::move(adjust)); @@ -357,7 +361,17 @@ void AdjustAndSend( } else { field_size = req.field_size(); } - auto [ops, seeds, pad_length] = BuildOperand(req, field_size, decryptor); + ElementType eltype = ElementType::kRing; + // enable eltype for selected requests here + // later all requests may support gfmp + if constexpr (std::is_same_v || + std::is_same_v) { + if (req.element_type() == ElType::GFMP) { + eltype = ElementType::kGfmp; + } + } + auto [ops, seeds, pad_length] = + BuildOperand(req, field_size, decryptor, eltype); if constexpr (std::is_same_v || std::is_same_v) { @@ -475,6 +489,12 @@ class ServiceImpl final : public BeaverService { Adjust(controller, req, rsp, done); } + void AdjustMulPriv(::google::protobuf::RpcController* controller, + const AdjustMulPrivRequest* req, AdjustResponse* rsp, + ::google::protobuf::Closure* done) override { + Adjust(controller, req, rsp, done); + } + void 
AdjustSquare(::google::protobuf::RpcController* controller, const AdjustSquareRequest* req, AdjustResponse* rsp, ::google::protobuf::Closure* done) override { diff --git a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto index 6b1b3675..23fd3025 100644 --- a/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto +++ b/libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/service.proto @@ -25,6 +25,13 @@ enum ErrorCode { StreamAcceptError = 3; } +// The type of element in the field. +// Match the enum in libspu/mpc/common/prg_tensor.h +enum ElType { + UNSPECIFIED = 0; + RING = 1; + GFMP = 2; +} // PRG generated buffer metainfo. // BeaverService replay PRG to generate same buffer using each party's prg_seed // encrypted by server's public key. PrgBufferMeta represent {world_size} @@ -42,6 +49,8 @@ service BeaverService { // V1 adjust ops rpc AdjustMul(AdjustMulRequest) returns (AdjustResponse); + rpc AdjustMulPriv(AdjustMulPrivRequest) returns (AdjustResponse); + rpc AdjustSquare(AdjustSquareRequest) returns (AdjustResponse); rpc AdjustDot(AdjustDotRequest) returns (AdjustResponse); @@ -69,6 +78,27 @@ message AdjustMulRequest { // adjust_c = ra * rb - rc // make // ra * rb = (adjust_c + rc) + + // element type supported: "GFMP", "RING" + ElType element_type = 3; + // if element type is "GFMP" then all ring ops will be changed to gfmp +} + +message AdjustMulPrivRequest { + // input 2 prg buffer + // first is a or b [one party holds a slice, another b slice] + // second is c + repeated PrgBufferMeta prg_inputs = 1; + // What field size should be used to interpret buffer content + uint32 field_size = 2; + // output + // adjust_c = a * b - rc + // make + // a * b = (adjust_c + rc) + + // element type supported: "GFMP", "RING" + ElType element_type = 3; + // if element type is "GFMP" then all ring ops will be changed to gfmp } message AdjustSquareRequest { diff --git a/libspu/mpc/semi2k/beaver/beaver_interface.h b/libspu/mpc/semi2k/beaver/beaver_interface.h index d610f380..89c58267 100644 --- a/libspu/mpc/semi2k/beaver/beaver_interface.h +++ b/libspu/mpc/semi2k/beaver/beaver_interface.h @@ -41,6 +41,7 @@ class Beaver { std::vector encrypted_seeds; int64_t size; FieldType field; + ElementType eltype; }; using Array = yacl::Buffer; @@ -50,8 +51,11 @@ class Beaver { virtual ~Beaver() = default; virtual Triple Mul(FieldType field, int64_t size, - ReplayDesc* x_desc = nullptr, - ReplayDesc* y_desc = nullptr) = 0; + ReplayDesc* x_desc = nullptr, ReplayDesc* y_desc = nullptr, + ElementType eltype = ElementType::kRing) = 0; + + virtual Pair MulPriv(FieldType field, int64_t size, + ElementType eltype = ElementType::kRing) = 0; virtual Pair Square(FieldType field, int64_t size, ReplayDesc* x_desc = nullptr) = 0; diff --git a/libspu/mpc/semi2k/exp.cc b/libspu/mpc/semi2k/exp.cc new file mode 100644 index 00000000..34dba15f --- /dev/null +++ b/libspu/mpc/semi2k/exp.cc @@ -0,0 +1,97 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#include "exp.h" + +#include "prime_utils.h" +#include "type.h" + +#include "libspu/mpc/utils/gfmp.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::semi2k { + +// Given [x*2^fxp] mod 2k for x +// compute [exp(x) * 2^fxp] mod 2^k + +// Assume x is in valid range, otherwise the error may be too large to +// use this method. + +NdArrayRef ExpA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + const size_t fxp = ctx->sctx()->getFxpBits(); + SPU_ENFORCE( + fxp < 64, + "fxp must be less than 64 for this method, or shift bit overflow ", + "may occur"); + auto field = in.eltype().as()->field(); + NdArrayRef x = in.clone(); + NdArrayRef out; + + // TODO: set different values for FM64 FM32 + const size_t kExpFxp = (field == FieldType::FM128) ? 24 : 13; + + const int rank = ctx->sctx()->lctx()->Rank(); + DISPATCH_ALL_FIELDS(field, [&]() { + auto total_fxp = kExpFxp + fxp; + // note that x is already encoded with fxp + // this conv scale further converts x int fixed point numbers with + // total_fxp + const ring2k_t exp_conv_scale = std::roundf(M_LOG2E * (1L << kExpFxp)); + + // offset scale should directly encoded to a fixed point with total_fxp + const ring2k_t offset = + ctx->sctx()->config().experimental_exp_prime_offset(); + const ring2k_t offset_scaled = offset << total_fxp; + + NdArrayView _x(x); + if (rank == 0) { + pforeach(0, x.numel(), [&](ring2k_t i) { + _x[i] *= exp_conv_scale; + _x[i] += offset_scaled; + }); + } else { + pforeach(0, x.numel(), [&](ring2k_t i) { _x[i] *= exp_conv_scale; }); + } + size_t shr_width = SizeOf(field) * 8 - fxp; + + const ring2k_t kBit = 1; + auto shifted_bit = kBit << total_fxp; + const ring2k_t frac_mask = shifted_bit - 1; + + auto int_part = ring_arshift(x, {static_cast(total_fxp)}); + + // convert from ring-share (int-part) to a prime share over p - 1 + int_part = ProbConvRing2k(int_part, rank, shr_width); + NdArrayView int_part_view(int_part); + + pforeach(0, x.numel(), [&](int64_t i) { + // y = 2^int_part mod p + ring2k_t y = exp_mod(2, int_part_view[i]); + // z = 2^fract_part in RR + double frac_part = static_cast(_x[i] & frac_mask) / shifted_bit; + frac_part = std::pow(2., frac_part); + + // Multiply the 2^{int_part} * 2^{frac_part} mod p + // note that mul_mod uses mersenne prime as modulus according to field + int_part_view[i] = mul_mod( + y, static_cast(std::roundf(frac_part * (kBit << fxp)))); + }); + + NdArrayRef muled = MulPrivModMP(ctx, int_part.as(makeType(field))); + + out = ConvMP(ctx, muled, offset + fxp); + }); + return out.as(in.eltype()); +} + +} // namespace spu::mpc::semi2k \ No newline at end of file diff --git a/libspu/mpc/semi2k/exp.h b/libspu/mpc/semi2k/exp.h new file mode 100644 index 00000000..fcc4711e --- /dev/null +++ b/libspu/mpc/semi2k/exp.h @@ -0,0 +1,37 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
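+
+// Editor's note - an informal sketch of the method implemented in exp.cc,
+// for orientation only:
+//   exp(x) = 2^(x * log2(e)). The share of x is scaled by log2(e) (plus a
+//   public offset to keep the value positive) and split into an integer part
+//   u and a fractional part f. The ring share of u is converted into a share
+//   modulo p - 1 (ProbConvRing2k), so by Fermat's little theorem the parties'
+//   local values 2^(u_0) and 2^(u_1) multiply to 2^u mod p. Each party also
+//   raises 2 to its fractional part locally in floating point. One private
+//   multiplication modulo the Mersenne prime (MulPrivModMP) then produces
+//   shares of 2^(u + f) = exp(x) up to fixed-point scaling, and ConvMP
+//   converts the result back to an additive sharing over the ring 2^k.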
+ +#pragma once + +#include "libspu/mpc/kernel.h" + +namespace spu::mpc::semi2k { + +// Given [x*2^fxp] mod 2k for x +// compute [exp(x) * 2^fxp] mod 2^k +// Example: +// spu::mpc::semi2k::ExpA exp; +// outp = exp.proc(&kcontext, ring2k_shr); +class ExpA : public UnaryKernel { + public: + static constexpr const char* kBindName() { return "exp_a"; } + + ce::CExpr latency() const override { return ce::Const(2); } + + ce::CExpr comm() const override { return 2 * ce::K(); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; + +} // namespace spu::mpc::semi2k diff --git a/libspu/mpc/semi2k/prime_utils.cc b/libspu/mpc/semi2k/prime_utils.cc new file mode 100644 index 00000000..b0911331 --- /dev/null +++ b/libspu/mpc/semi2k/prime_utils.cc @@ -0,0 +1,201 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "prime_utils.h" + +#include "type.h" + +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/semi2k/state.h" +#include "libspu/mpc/utils/gfmp.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::semi2k { + +NdArrayRef ProbConvRing2k(const NdArrayRef& inp_share, int rank, + size_t shr_width) { + SPU_ENFORCE(inp_share.eltype().isa()); + SPU_ENFORCE(rank >= 0 && rank <= 1); + + auto eltype = inp_share.eltype(); + NdArrayRef output_share(eltype, inp_share.shape()); + + auto ring_ty = eltype.as()->field(); + uint128_t shifted_bit = 1; + shifted_bit <<= shr_width; + auto mask = shifted_bit - 1; + // x mod p - 1 + // in our case p > 2^shr_width + + DISPATCH_ALL_FIELDS(ring_ty, [&]() { + const auto prime = ScalarTypeToPrime::prime; + ring2k_t prime_minus_one = (prime - 1); + NdArrayView inp(inp_share); + NdArrayView output_share_view(output_share); + pforeach(0, output_share.numel(), [&](int64_t i) { + output_share_view[i] = + rank == 0 ? ((inp[i] & mask) % prime_minus_one) + // numerical considerations here + // we wanted to work on ring 2k or field p - 1 + // however, if we do not add p -1 + // then the computation will resort to int128 + // due to the way computer works + : ((inp[i] & mask) + prime_minus_one - shifted_bit) % + prime_minus_one; + }); + }); + return output_share; +} + +NdArrayRef UnflattenBuffer(yacl::Buffer&& buf, const NdArrayRef& x) { + return NdArrayRef(std::make_shared(std::move(buf)), x.eltype(), + x.shape()); +} + +// P0 holds x,P1 holds y +// Beaver generates ab = c_0 + c_1 +// Give (a, c_0) to P0 +// Give (b, c_1) to P1 +std::tuple MulPrivPrep(KernelEvalContext* ctx, + const NdArrayRef& x) { + const auto field = x.eltype().as()->field(); + auto* beaver = ctx->getState()->beaver(); + + // generate beaver multiple triple. + NdArrayRef a_or_b; + NdArrayRef c; + + const size_t numel = x.shape().numel(); + auto [a_or_b_buf, c_buf] = beaver->MulPriv( + field, numel, // + x.eltype().isa() ? 
ElementType::kGfmp : ElementType::kRing); + SPU_ENFORCE(static_cast(a_or_b_buf.size()) == numel * SizeOf(field)); + SPU_ENFORCE(static_cast(c_buf.size()) == numel * SizeOf(field)); + + a_or_b = UnflattenBuffer(std::move(a_or_b_buf), x); + c = UnflattenBuffer(std::move(c_buf), x); + + return {std::move(a_or_b), std::move(c)}; +} + +// P0 holds x,P1 holds y +// Beaver generates ab = c_0 + c_1 +// Give (a, c_0) to P0 +// Give (b, c_1) to P1 +// +// - P0 sends (x+a) to P1 ; P1 sends (y+b) to P0 +// - P0 calculates z0 = x(y+b) + c0 ; P1 calculates z1 = -b(x+a) + c1 +NdArrayRef MulPriv(KernelEvalContext* ctx, const NdArrayRef& x) { + SPU_ENFORCE(x.eltype().isa()); + auto* comm = ctx->getState(); + + NdArrayRef a_or_b, c, xa_or_yb; + + std::tie(a_or_b, c) = MulPrivPrep(ctx, x); + + // P0 sends (x+a) to P1 ; P1 sends (y+b) to P0 + comm->sendAsync(comm->nextRank(), ring_add(a_or_b, x), "(x + a) or (y + b)"); + xa_or_yb = comm->recv(comm->prevRank(), x.eltype(), "(x + a) or (y + b)"); + // note that our rings are commutative. + if (comm->getRank() == 0) { + ring_add_(c, ring_mul(std::move(xa_or_yb), x)); + } + if (comm->getRank() == 1) { + ring_sub_(c, ring_mul(std::move(xa_or_yb), a_or_b)); + } + return c; +} + +NdArrayRef MulPrivModMP(KernelEvalContext* ctx, const NdArrayRef& x) { + SPU_ENFORCE(x.eltype().isa()); + auto* comm = ctx->getState(); + + NdArrayRef a_or_b, c, xa_or_yb; + std::tie(a_or_b, c) = MulPrivPrep(ctx, x); + + comm->sendAsync(comm->nextRank(), gfmp_add_mod(a_or_b, x), "xa_or_yb"); + xa_or_yb = + comm->recv(comm->prevRank(), x.eltype(), "xa_or_yb").reshape(x.shape()); + + // note that our rings are commutative. + if (comm->getRank() == 0) { + gfmp_add_mod_(c, gfmp_mul_mod(std::move(xa_or_yb), x)); + } + if (comm->getRank() == 1) { + gfmp_sub_mod_(c, gfmp_mul_mod(std::move(xa_or_yb), a_or_b)); + } + return c; +} + +// We assume the input is ``positive'' +// Given h0 + h1 = h mod p and h < p / 2 +// Define b0 = 1{h0 >= p/2} +// b1 = 1{h1 >= p/2} +// Compute w = 1{h0 + h1 >= p} +// It can be proved that w = (b0 or b1) = not (not b0 and not b1) +NdArrayRef WrapBitModMP(KernelEvalContext* ctx, const NdArrayRef& x) { + // create a wrap bit NdArrayRef of the same shape as in + NdArrayRef b(x.eltype(), x.shape()); + + // for each element, we compute b = 1{h < p/2} for each private share piece + const auto numel = x.numel(); + const auto field = x.eltype().as()->field(); + + DISPATCH_ALL_FIELDS(field, [&]() { + ring2k_t prime = ScalarTypeToPrime::prime; + ring2k_t phalf = prime >> 1; + NdArrayView _x(x); + NdArrayView _b(b); + pforeach(0, numel, [&](int64_t idx) { + _b[idx] = static_cast(_x[idx] < phalf); + }); + + // do private mul + b = MulPriv(ctx, b.as(makeType(field))); + + // map 1 to 0 and 0 to 1, use 1 - x + if (ctx->getState()->getRank() == 0) { + pforeach(0, numel, [&](int64_t idx) { _b[idx] = 1 - _b[idx]; }); + } else { + pforeach(0, numel, [&](int64_t idx) { _b[idx] = -_b[idx]; }); + } + }); + + return b; +} +// Mersenne Prime share -> Ring2k share + +NdArrayRef ConvMP(KernelEvalContext* ctx, const NdArrayRef& h, + uint truncate_nbits) { + // calculate wrap bit + NdArrayRef w = WrapBitModMP(ctx, h); + const auto field = h.eltype().as()->field(); + const auto numel = h.numel(); + + // x = (h - p * w) mod 2^k + + NdArrayRef x(makeType(field), h.shape()); + DISPATCH_ALL_FIELDS(field, [&]() { + auto prime = ScalarTypeToPrime::prime; + NdArrayView h_view(h); + NdArrayView _x(x); + NdArrayView w_view(w); + pforeach(0, numel, [&](int64_t idx) { + _x[idx] = static_cast(h_view[idx] >> 
truncate_nbits) - + static_cast(prime >> truncate_nbits) * w_view[idx]; + }); + }); + return x; +} + +} // namespace spu::mpc::semi2k \ No newline at end of file diff --git a/libspu/mpc/semi2k/prime_utils.h b/libspu/mpc/semi2k/prime_utils.h new file mode 100644 index 00000000..a04acf3a --- /dev/null +++ b/libspu/mpc/semi2k/prime_utils.h @@ -0,0 +1,46 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "libspu/core/context.h" +#include "libspu/mpc/kernel.h" + +namespace spu::mpc::semi2k { +// Ring2k share -> Mersenne Prime - 1 share +// Given x0 + x1 = x mod 2^k +// Compute h0 + h1 = x mod p with probability > 1 - |x|/2^k +NdArrayRef ProbConvRing2k(const NdArrayRef& inp_share, int rank, + size_t shr_width); + +// Mul open private share +std::tuple MulPrivPrep(KernelEvalContext* ctx, + const NdArrayRef& x); + +// Note that [x] = (x_alice, x_bob) and x_alice + x_bob = x +// Note that we actually want to find the muliplication of x_alice and x_bob +// this function is currently achieved by doing (x_alice, 0) * (0, x_bob) +// optimization is possible. +NdArrayRef MulPrivModMP(KernelEvalContext* ctx, const NdArrayRef& x); +// We assume the input is ``positive'' +// Given h0 + h1 = h mod p and h < p / 2 +// Define b0 = 1{h0 >= p/2} +// b1 = 1{h1 >= p/2} +// Compute w = 1{h0 + h1 >= p} +// It can be proved that w = (b0 or b1) +NdArrayRef WrapBitModMP(KernelEvalContext* ctx, const NdArrayRef& x); + +// Mersenne Prime share -> Ring2k share +NdArrayRef ConvMP(KernelEvalContext* ctx, const NdArrayRef& h, + uint truncate_nbits); +} // namespace spu::mpc::semi2k \ No newline at end of file diff --git a/libspu/mpc/semi2k/protocol.cc b/libspu/mpc/semi2k/protocol.cc index 35cd436c..28a2d9c5 100644 --- a/libspu/mpc/semi2k/protocol.cc +++ b/libspu/mpc/semi2k/protocol.cc @@ -20,6 +20,7 @@ #include "libspu/mpc/semi2k/arithmetic.h" #include "libspu/mpc/semi2k/boolean.h" #include "libspu/mpc/semi2k/conversion.h" +#include "libspu/mpc/semi2k/exp.h" #include "libspu/mpc/semi2k/permute.h" #include "libspu/mpc/semi2k/state.h" #include "libspu/mpc/semi2k/type.h" @@ -76,6 +77,12 @@ void regSemi2kProtocol(SPUContext* ctx, if (lctx->WorldSize() == 2) { ctx->prot()->regKernel(); + + // only supports 2pc fm128 for now + if (ctx->getField() == FieldType::FM128 && + ctx->config().experimental_enable_exp_prime()) { + ctx->prot()->regKernel(); + } } // ctx->prot()->regKernel(); } diff --git a/libspu/mpc/semi2k/protocol_test.cc b/libspu/mpc/semi2k/protocol_test.cc index 66911344..eb1a6c60 100644 --- a/libspu/mpc/semi2k/protocol_test.cc +++ b/libspu/mpc/semi2k/protocol_test.cc @@ -25,7 +25,11 @@ #include "libspu/mpc/api_test.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/semi2k/beaver/beaver_impl/ttp_server/beaver_server.h" +#include "libspu/mpc/semi2k/exp.h" +#include "libspu/mpc/semi2k/prime_utils.h" #include "libspu/mpc/semi2k/state.h" +#include "libspu/mpc/semi2k/type.h" +#include "libspu/mpc/utils/gfmp.h" #include 
"libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -36,6 +40,12 @@ RuntimeConfig makeConfig(FieldType field) { RuntimeConfig conf; conf.set_protocol(ProtocolKind::SEMI2K); conf.set_field(field); + if (field == FieldType::FM64) { + conf.set_fxp_fraction_bits(17); + } else if (field == FieldType::FM128) { + conf.set_fxp_fraction_bits(40); + } + conf.set_experimental_enable_exp_prime(true); return conf; } @@ -404,4 +414,173 @@ TEST_P(BeaverCacheTest, SquareA) { }); } +TEST_P(BeaverCacheTest, priv_mul_test) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + // only supports 2 party (not counting beaver) + if (npc != 2) { + return; + } + NdArrayRef ring2k_shr[2]; + + int64_t numel = 1; + FieldType field = conf.field(); + + std::vector real_vec(numel); + for (int64_t i = 0; i < numel; ++i) { + real_vec[i] = 2; + } + + auto rnd_msg = gfmp_zeros(field, {numel}); + + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView xmsg(rnd_msg); + pforeach(0, numel, [&](int64_t i) { xmsg[i] = std::round(real_vec[i]); }); + }); + + ring2k_shr[0] = rnd_msg; + ring2k_shr[1] = rnd_msg; + + NdArrayRef input, outp_pub; + NdArrayRef outp[2]; + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + KernelEvalContext kcontext(obj.get()); + + int rank = lctx->Rank(); + + outp[rank] = spu::mpc::semi2k::MulPrivModMP(&kcontext, ring2k_shr[rank]); + }); + auto got = gfmp_add_mod(outp[0], outp[1]); + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView got_view(got); + + double max_err = 0.0; + double min_err = 99.0; + for (int64_t i = 0; i < numel; ++i) { + double expected = real_vec[i] * real_vec[i]; + double got = static_cast(got_view[i]); + max_err = std::max(max_err, std::abs(expected - got)); + min_err = std::min(min_err, std::abs(expected - got)); + } + ASSERT_LE(min_err, 1e-3); + ASSERT_LE(max_err, 1e-3); + }); +} + +TEST_P(BeaverCacheTest, exp_mod_test) { + const RuntimeConfig& conf = std::get<1>(GetParam()); + FieldType field = conf.field(); + + DISPATCH_ALL_FIELDS(field, [&]() { + // exponents < 32 + ring2k_t exponents[5] = {10, 21, 27}; + + for (ring2k_t exponent : exponents) { + ring2k_t y = exp_mod(2, exponent); + ring2k_t prime = ScalarTypeToPrime::prime; + ring2k_t prime_minus_one = (prime - 1); + ring2k_t shifted_bit = 1; + shifted_bit <<= exponent; + EXPECT_EQ(y, shifted_bit % prime_minus_one); + } + }); +} + +TEST_P(BeaverCacheTest, ExpA) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + // exp only supports 2 party (not counting beaver) + // only supports FM128 for now + // note not using ctx->hasKernel("exp_a") because we are testing kernel + // registration as well. 
+ if (npc != 2 || conf.field() != FieldType::FM128) { + return; + } + auto fxp = conf.fxp_fraction_bits(); + + NdArrayRef ring2k_shr[2]; + + int64_t numel = 100; + FieldType field = conf.field(); + + // how to define and achieve high pricision for e^20 + std::uniform_real_distribution dist(-18.0, 15.0); + std::default_random_engine rd; + std::vector real_vec(numel); + for (int64_t i = 0; i < numel; ++i) { + // make the input a fixed point number, eliminate the fixed point encoding + // error + real_vec[i] = + static_cast(std::round((dist(rd) * (1L << fxp)))) / (1L << fxp); + } + + auto rnd_msg = ring_zeros(field, {numel}); + + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView xmsg(rnd_msg); + pforeach(0, numel, [&](int64_t i) { + xmsg[i] = std::round(real_vec[i] * (1L << fxp)); + }); + }); + + ring2k_shr[0] = ring_rand(field, rnd_msg.shape()) + .as(makeType(field)); + ring2k_shr[1] = ring_sub(rnd_msg, ring2k_shr[0]) + .as(makeType(field)); + + NdArrayRef outp_pub; + NdArrayRef outp[2]; + + utils::simulate(npc, [&](const std::shared_ptr& lctx) { + auto obj = factory(conf, lctx); + + KernelEvalContext kcontext(obj.get()); + + int rank = lctx->Rank(); + + size_t bytes = lctx->GetStats()->sent_bytes; + size_t action = lctx->GetStats()->sent_actions; + + spu::mpc::semi2k::ExpA exp; + outp[rank] = exp.proc(&kcontext, ring2k_shr[rank]); + + bytes = lctx->GetStats()->sent_bytes - bytes; + action = lctx->GetStats()->sent_actions - action; + SPDLOG_INFO("ExpA ({}) for n = {}, sent {} MiB ({} B per), actions {}", + field, numel, bytes * 1. / 1024. / 1024., bytes * 1. / numel, + action); + }); + assert(outp[0].eltype() == ring2k_shr[0].eltype()); + auto got = ring_add(outp[0], outp[1]); + ring_print(got, "exp result"); + DISPATCH_ALL_FIELDS(field, [&]() { + using sT = std::make_signed::type; + NdArrayView got_view(got); + + double max_err = 0.0; + for (int64_t i = 0; i < numel; ++i) { + double expected = std::exp(real_vec[i]); + expected = static_cast(std::round((expected * (1L << fxp)))) / + (1L << fxp); + double got = static_cast(got_view[i]) / (1L << fxp); + // cout left here for future improvement + std::cout << "expected: " << fmt::format("{0:f}", expected) + << ", got: " << fmt::format("{0:f}", got) << std::endl; + std::cout << "expected: " + << fmt::format("{0:b}", + static_cast(expected * (1L << fxp))) + << ", got: " << fmt::format("{0:b}", got_view[i]) << std::endl; + max_err = std::max(max_err, std::abs(expected - got)); + } + ASSERT_LE(max_err, 1e-0); + }); +} } // namespace spu::mpc::test diff --git a/libspu/mpc/utils/BUILD.bazel b/libspu/mpc/utils/BUILD.bazel index 00287e26..2b494ff1 100644 --- a/libspu/mpc/utils/BUILD.bazel +++ b/libspu/mpc/utils/BUILD.bazel @@ -76,6 +76,7 @@ spu_cc_library( deps = [ ":linalg", "//libspu/core:ndarray_ref", + "//libspu/core:type_util", "@yacl//yacl/crypto/rand", "@yacl//yacl/crypto/tools:prg", "@yacl//yacl/utils:parallel", @@ -90,6 +91,36 @@ spu_cc_test( ], ) +spu_cc_library( + name = "gfmp_ops", + srcs = ["gfmp_ops.cc"], + hdrs = ["gfmp_ops.h"], + copts = select({ + "@platforms//cpu:x86_64": [ + "-mavx", + ], + "//conditions:default": [], + }), + deps = [ + ":gfmp", + ":linalg", + ":ring_ops", + "//libspu/core:ndarray_ref", + "@yacl//yacl/crypto/rand", + "@yacl//yacl/crypto/tools:prg", + "@yacl//yacl/utils:parallel", + ], +) + +spu_cc_library( + name = "gfmp", + hdrs = ["gfmp.h"], + deps = [ + "//libspu/core:type_util", + "@yacl//yacl/base:int128", + ], +) + spu_cc_binary( name = "ring_ops_bench", srcs = 
["ring_ops_bench.cc"], diff --git a/libspu/mpc/utils/gfmp.h b/libspu/mpc/utils/gfmp.h new file mode 100644 index 00000000..8dfc4f5f --- /dev/null +++ b/libspu/mpc/utils/gfmp.h @@ -0,0 +1,168 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yacl/base/int128.h" + +#include "libspu/core/type_util.h" + +#define EIGEN_HAS_OPENMP + +#include "Eigen/Core" + +namespace spu::mpc { + +inline uint8_t mul(uint8_t x, uint8_t y, uint8_t* z) { + uint16_t hi = static_cast(x) * static_cast(y); + auto lo = static_cast(hi); + if (z != nullptr) { + *z = static_cast(hi >> 8); + } + return lo; +} + +inline uint32_t mul(uint32_t x, uint32_t y, uint32_t* z) { + uint64_t hi = static_cast(x) * static_cast(y); + auto lo = static_cast(hi); + if (z != nullptr) { + *z = static_cast(hi >> 32); + } + return lo; +} + +inline uint64_t mul(uint64_t x, uint64_t y, uint64_t* z) { + uint128_t hi = static_cast(x) * static_cast(y); + auto lo = static_cast(hi); + if (z != nullptr) { + *z = static_cast(hi >> 64); + } + return lo; +} + +inline uint128_t mul(uint128_t x, uint128_t y, uint128_t* z) { + uint64_t x_lo = x & 0xFFFFFFFFFFFFFFFF; + uint64_t x_hi = x >> 64; + uint64_t y_lo = y & 0xFFFFFFFFFFFFFFFF; + uint64_t y_hi = y >> 64; + + uint128_t lo = static_cast(x_lo) * y_lo; + + uint128_t xl_yh = static_cast(x_lo) * y_hi; + uint128_t xh_yl = static_cast(x_hi) * y_lo; + + lo += xl_yh << 64; + uint128_t hi = static_cast(lo < (xl_yh << 64)); + + lo += xh_yl << 64; + hi += static_cast(lo < (xh_yl << 64)); + hi += static_cast(x_hi) * y_hi; + + hi += xl_yh >> 64; + hi += xh_yl >> 64; + if (z != nullptr) { + *z = hi; + } + return lo; +} + +template , bool> = true> +inline T mul_mod(T x, T y) { + T c = 0; + T e = mul(x, y, &c); + T p = ScalarTypeToPrime::prime; + size_t mp_exp = ScalarTypeToPrime::exp; + T ret = (e & p) + ((e >> mp_exp) ^ (c << (sizeof(T) * 8 - mp_exp))); + return (ret >= p) ? ret - p : ret; +} + +template , bool> = true> +inline T add_mod(T x, T y) { + T ret = x + y; + T p = ScalarTypeToPrime::prime; + return (ret >= p) ? ret - p : ret; +} + +template , bool> = true> +inline T add_inv(T x) { + T p = ScalarTypeToPrime::prime; + return x ^ p; +} + +// Extended Euclidean Algorithm +// ax + by = gcd(a, b) +template , bool> = true> +void extend_gcd(T a, T b, T& x, T& y) { + if (b == 0) { + x = 1; + y = 0; + return; + } + extend_gcd(b, static_cast(a % b), y, x); + T tmp = mul_mod(static_cast(a / b), x); + y = add_mod(y, add_inv(tmp)); +} + +template , bool> = true> +inline T mul_inv(T in) { + T x; + T y; + T p = ScalarTypeToPrime::prime; + extend_gcd(p, in, x, y); + return y; +} + +template , bool> = true> +inline T mod_p(T in) { + T p = ScalarTypeToPrime::prime; + size_t mp_exp = ScalarTypeToPrime::exp; + T i = (in & p) + (in >> mp_exp); + return i >= p ? 
i - p : i; +} + +// the following code references SEAL library +// https://github.com/microsoft/SEAL/blob/main/src/seal/util/uintarithsmallmod.cpp +template , bool> = true> +inline T exp_mod(T operand, T exponent) { + // Fast cases + if (exponent == 0) { + // Result is supposed to be only one digit + return 1; + } + + if (exponent == 1) { + return operand; + } + + // Perform binary exponentiation. + T power = operand; + T product = 0; + T intermediate = 1; + + // Initially: power = operand and intermediate = 1, product is irrelevant. + while (true) { + if (exponent & 1) { + product = mul_mod(power, intermediate); + std::swap(product, intermediate); + } + exponent >>= 1; + if (exponent == 0) { + break; + } + product = mul_mod(power, power); + std::swap(product, power); + } + return intermediate; +} +} // namespace spu::mpc \ No newline at end of file diff --git a/libspu/mpc/utils/gfmp_ops.cc b/libspu/mpc/utils/gfmp_ops.cc new file mode 100644 index 00000000..c336a8ba --- /dev/null +++ b/libspu/mpc/utils/gfmp_ops.cc @@ -0,0 +1,251 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define PFOR_GRAIN_SIZE 4096 + +#include "libspu/mpc/utils/gfmp_ops.h" + +#include + +#include "absl/types/span.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/crypto/tools/prg.h" + +#include "libspu/mpc/utils/gfmp.h" +#include "libspu/mpc/utils/linalg.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc { +namespace { + +#define SPU_ENFORCE_RING(x) \ + SPU_ENFORCE((x).eltype().isa(), "expect ring type, got={}", \ + (x).eltype()); + +#define SPU_ENFORCE_GFMP(x) \ + SPU_ENFORCE((x).eltype().isa(), "expect gfmp type, got={}", \ + (x).eltype()); + +#define ENFORCE_EQ_ELSIZE_AND_SHAPE(lhs, rhs) \ + SPU_ENFORCE((lhs).elsize() == (rhs).elsize(), \ + "type size mismatch lhs={}, rhs={}", (lhs).eltype(), \ + (rhs).eltype()); \ + SPU_ENFORCE((lhs).shape() == (rhs).shape(), \ + "numel mismatch, lhs={}, rhs={}", lhs, rhs); + +// Fast mod operation for Mersenne prime +void gfmp_mod_impl(NdArrayRef& ret, const NdArrayRef& x) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + const auto* ty = ret.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = mod_p(_x[idx]); }); + }); +} + +void gfmp_mul_mod_impl(NdArrayRef& ret, const NdArrayRef& x, + const NdArrayRef& y) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); + const auto* ty = x.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + NdArrayView _y(y); + pforeach(0, numel, + [&](int64_t idx) { _ret[idx] = mul_mod(_x[idx], _y[idx]); }); + }); +} + +void gfmp_add_mod_impl(NdArrayRef& ret, const NdArrayRef& x, + const NdArrayRef& y) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); + const 
auto* ty = x.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + NdArrayView _y(y); + pforeach(0, numel, + [&](int64_t idx) { _ret[idx] = add_mod(_x[idx], _y[idx]); }); + }); +} + +void gfmp_sub_mod_impl(NdArrayRef& ret, const NdArrayRef& x, + const NdArrayRef& y) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); + const auto* ty = x.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + NdArrayView _y(y); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = add_mod(_x[idx], add_inv(_y[idx])); + }); + }); +} + +void gfmp_div_mod_impl(NdArrayRef& ret, const NdArrayRef& x, + const NdArrayRef& y) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); + const auto* ty = x.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + NdArrayView _y(y); + pforeach(0, numel, [&](int64_t idx) { + _ret[idx] = mul_mod(_x[idx], mul_inv(_y[idx])); + }); + }); +} + +} // namespace +NdArrayRef gfmp_zeros(FieldType field, const Shape& shape) { + NdArrayRef ret(makeType(field), shape); + auto numel = ret.numel(); + + return DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = 0; }); + return ret; + }); +} +NdArrayRef gfmp_rand(FieldType field, const Shape& shape) { + uint64_t cnt = 0; + return gfmp_rand(field, shape, yacl::crypto::SecureRandSeed(), &cnt); +} + +NdArrayRef gfmp_rand(FieldType field, const Shape& shape, uint128_t prg_seed, + uint64_t* prg_counter) { + constexpr yacl::crypto::SymmetricCrypto::CryptoType kCryptoType = + yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR; + constexpr uint128_t kAesInitialVector = 0U; + NdArrayRef res(makeType(field), shape); + DISPATCH_ALL_FIELDS(field, [&]() { + *prg_counter = yacl::crypto::FillPRandWithMersennePrime( + kCryptoType, prg_seed, kAesInitialVector, *prg_counter, + absl::MakeSpan(&res.at(0), res.numel())); + }); + return res; +} + +NdArrayRef gfmp_mod(const NdArrayRef& x) { + SPU_ENFORCE_GFMP(x); + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_mod_impl(ret, x); + return ret; +} + +void gfmp_mod_(NdArrayRef& x) { + SPU_ENFORCE_GFMP(x); + gfmp_mod_impl(x, x); +} + +NdArrayRef gfmp_mul_mod(const NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_mul_mod_impl(ret, x, y); + return ret; +} + +void gfmp_mul_mod_(NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + gfmp_mul_mod_impl(x, x, y); +} + +NdArrayRef gfmp_div_mod(const NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_div_mod_impl(ret, x, y); + return ret; +} + +void gfmp_div_mod_(NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + gfmp_div_mod_impl(x, x, y); +} + +NdArrayRef gfmp_add_mod(const NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_add_mod_impl(ret, x, y); + return ret; +} + +void gfmp_add_mod_(NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + gfmp_add_mod_impl(x, x, y); +} + +NdArrayRef 
gfmp_sub_mod(const NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_sub_mod_impl(ret, x, y); + return ret; +} + +void gfmp_sub_mod_(NdArrayRef& x, const NdArrayRef& y) { + SPU_ENFORCE_GFMP(x); + SPU_ENFORCE_GFMP(y); + gfmp_sub_mod_impl(x, x, y); +} + +// not requiring and not casting field. +void gfmp_exp_mod_impl(NdArrayRef& ret, const NdArrayRef& x, + const NdArrayRef& y) { + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); + const auto* ty = x.eltype().as(); + const auto field = ty->field(); + const auto numel = x.numel(); + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView _ret(ret); + NdArrayView _x(x); + NdArrayView _y(y); + pforeach(0, numel, + [&](int64_t idx) { _ret[idx] = exp_mod(_x[idx], _y[idx]); }); + }); +} + +NdArrayRef gfmp_exp_mod(const NdArrayRef& x, const NdArrayRef& y) { + NdArrayRef ret(x.eltype(), x.shape()); + gfmp_exp_mod_impl(ret, x, y); + return ret; +} + +void gfmp_exp_mod_(NdArrayRef& x, const NdArrayRef& y) { + gfmp_exp_mod_impl(x, x, y); +} + +} // namespace spu::mpc diff --git a/libspu/mpc/utils/gfmp_ops.h b/libspu/mpc/utils/gfmp_ops.h new file mode 100644 index 00000000..31d275e8 --- /dev/null +++ b/libspu/mpc/utils/gfmp_ops.h @@ -0,0 +1,45 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/core/ndarray_ref.h" + +namespace spu::mpc { + +NdArrayRef gfmp_rand(FieldType field, const Shape& shape); +NdArrayRef gfmp_rand(FieldType field, const Shape& shape, uint128_t prg_seed, + uint64_t* prg_counter); + +NdArrayRef gfmp_zeros(FieldType field, const Shape& shape); + +NdArrayRef gfmp_mod(const NdArrayRef& x); +void gfmp_mod_(NdArrayRef& x); + +NdArrayRef gfmp_mul_mod(const NdArrayRef& x, const NdArrayRef& y); +void gfmp_mul_mod_(NdArrayRef& x, const NdArrayRef& y); + +NdArrayRef gfmp_div_mod(const NdArrayRef& x, const NdArrayRef& y); +void gfmp_div_mod_(NdArrayRef& x, const NdArrayRef& y); + +NdArrayRef gfmp_add_mod(const NdArrayRef& x, const NdArrayRef& y); +void gfmp_add_mod_(NdArrayRef& x, const NdArrayRef& y); + +NdArrayRef gfmp_sub_mod(const NdArrayRef& x, const NdArrayRef& y); +void gfmp_sub_mod_(NdArrayRef& x, const NdArrayRef& y); + +NdArrayRef gfmp_exp_mod(const NdArrayRef& x, const NdArrayRef& y); +void gfmp_exp_mod_(NdArrayRef& x, const NdArrayRef& y); + +} // namespace spu::mpc diff --git a/libspu/spu.proto b/libspu/spu.proto index 93793d83..a9050c5d 100644 --- a/libspu/spu.proto +++ b/libspu/spu.proto @@ -240,6 +240,7 @@ message RuntimeConfig { EXP_DEFAULT = 0; // Implementation defined. EXP_PADE = 1; // The pade approximation. EXP_TAYLOR = 2; // Taylor series approximation. + EXP_PRIME = 3; // exp prime only available for some implementations } // The exponent approximation method. 
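Editor's aside on the gfmp helpers added earlier in this patch (libspu/mpc/utils/gfmp.h): reduction modulo a Mersenne prime p = 2^k - 1 needs no division, because 2^k ≡ 1 (mod p), so bits above position k can simply be folded back onto the low bits. A standalone sketch of that idea, assuming the 61-bit Mersenne prime for concreteness (the per-field exponent actually comes from ScalarTypeToPrime); illustrative only:

    #include <cassert>
    #include <cstdint>

    // Illustrative only: scalar arithmetic modulo p = 2^61 - 1.
    constexpr uint64_t kExp = 61;
    constexpr uint64_t kP = (static_cast<uint64_t>(1) << kExp) - 1;

    // Fold the bits above position 61 back down, using 2^61 == 1 (mod p).
    uint64_t mod_p(uint64_t x) {
      uint64_t r = (x & kP) + (x >> kExp);
      return r >= kP ? r - kP : r;
    }

    // Multiply mod p: take the 128-bit product (GCC/Clang extension), then
    // fold the bits above position 61 onto the low bits.
    uint64_t mul_mod(uint64_t a, uint64_t b) {
      unsigned __int128 t = static_cast<unsigned __int128>(a) * b;
      uint64_t r =
          static_cast<uint64_t>(t & kP) + static_cast<uint64_t>(t >> kExp);
      return r >= kP ? r - kP : r;
    }

    // Square-and-multiply exponentiation, same shape as gfmp.h's exp_mod.
    uint64_t exp_mod(uint64_t base, uint64_t e) {
      uint64_t acc = 1;
      for (uint64_t pw = base; e != 0; e >>= 1, pw = mul_mod(pw, pw)) {
        if (e & 1) acc = mul_mod(acc, pw);
      }
      return acc;
    }

    int main() {
      assert(mul_mod(kP - 1, 2) == kP - 2);  // (-1) * 2 == -2 (mod p)
      assert(exp_mod(2, 61) == 1);           // 2^61 == 1 (mod p)
      return 0;
    }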
@@ -331,6 +332,25 @@ message RuntimeConfig {
   uint64 experimental_inter_op_concurrency = 104;
   // Enable use of private type
   bool experimental_enable_colocated_optimization = 105;
+
+  // Enable the experimental prime-based exp method.
+  bool experimental_enable_exp_prime = 106;
+
+  // The offset parameter for the exp prime method.
+  // It controls the valid input range of the exp prime method.
+  // The valid range is
+  //   ((47 - offset - 2fxp)/log_2(e), (125 - 2fxp - offset)/log_2(e)),
+  // and the values clamped to (when clamping is enabled) are
+  //   lower bound: (48 - offset - 2fxp)/log_2(e)
+  //   upper bound: (124 - 2fxp - offset)/log_2(e)
+  // The default offset is 13; an offset of 0 is not supported.
+  uint32 experimental_exp_prime_offset = 107;
+  // Whether to disable clamping to the lower bound
+  // (clamping to the lower bound is applied by default).
+  bool experimental_exp_prime_disable_lower_bound = 108;
+  // Whether to enable clamping to the upper bound
+  // (not applied by default).
+  bool experimental_exp_prime_enable_upper_bound = 109;
 }

 message TTPBeaverConfig {
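For reference, a minimal configuration sketch that would exercise the new code path end to end. The header path and setter names below are the usual protobuf-generated ones and should be treated as assumptions where this patch does not show them; EXP_PRIME is currently only wired up for SEMI2K with two parties over FM128 (see protocol.cc above), and a zero offset is replaced by the FM128 default of 13 during config population:

    #include "libspu/spu.pb.h"  // generated from libspu/spu.proto (assumed path)

    spu::RuntimeConfig MakeExpPrimeConfig() {
      spu::RuntimeConfig conf;
      conf.set_protocol(spu::ProtocolKind::SEMI2K);
      conf.set_field(spu::FieldType::FM128);
      conf.set_fxp_fraction_bits(40);
      // Route exp() through the prime-based method and enable the kernel.
      conf.set_fxp_exp_mode(spu::RuntimeConfig::EXP_PRIME);
      conf.set_experimental_enable_exp_prime(true);
      // Optional: leaving the offset at 0 picks the FM128 default of 13.
      // conf.set_experimental_exp_prime_offset(13);
      return conf;
    }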