diff --git a/yacl/crypto/primitives/ot/ferret_ote.cc b/yacl/crypto/primitives/ot/ferret_ote.cc index 65bd2e72..8726ee23 100644 --- a/yacl/crypto/primitives/ot/ferret_ote.cc +++ b/yacl/crypto/primitives/ot/ferret_ote.cc @@ -75,6 +75,10 @@ OtSendStore FerretOtExtSend(const std::shared_ptr& ctx, YACL_ENFORCE(ctx->WorldSize() == 2); // Make sure that OT has two parties YACL_ENFORCE(base_cot.Type() == OtStoreType::Compact); YACL_ENFORCE(base_cot.Size() >= FerretCotHelper(lpn_param, ot_num)); + YACL_ENFORCE(ot_num >= FerretCotHelper(lpn_param, ot_num), + "ot_num is {}, which should be much greater than the minium " + "size of base cot {}", + ot_num, FerretCotHelper(lpn_param, ot_num)); // get constants: the number of cot needed for mpcot phase const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n); @@ -152,6 +156,10 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr& ctx, YACL_ENFORCE(ctx->WorldSize() == 2); // Make sure that OT has two parties YACL_ENFORCE(base_cot.Type() == OtStoreType::Compact); YACL_ENFORCE(base_cot.Size() >= FerretCotHelper(lpn_param, ot_num)); + YACL_ENFORCE(ot_num >= FerretCotHelper(lpn_param, ot_num), + "ot_num is {}, which should be much greater than the minium " + "size of base cot {}", + ot_num, FerretCotHelper(lpn_param, ot_num)); // get constants: the number of cot needed for mpcot phase const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n); diff --git a/yacl/crypto/primitives/ot/ferret_ote_rn.h b/yacl/crypto/primitives/ot/ferret_ote_rn.h index b6db0f41..8d253481 100644 --- a/yacl/crypto/primitives/ot/ferret_ote_rn.h +++ b/yacl/crypto/primitives/ot/ferret_ote_rn.h @@ -39,7 +39,7 @@ void MpCotRNSend(const std::shared_ptr& ctx, // for each bin, call single-point cot for (uint64_t i = 0; i < batch_num; ++i) { const uint64_t this_size = - (i == batch_size - 1) ? full_size - i * batch_size : batch_size; + (i == batch_num - 1) ? full_size - i * batch_size : batch_size; const auto& cot_slice = cot.Slice(i * math::Log2Ceil(this_size), (i + 1) * math::Log2Ceil(this_size)); @@ -58,7 +58,7 @@ void MpCotRNRecv(const std::shared_ptr& ctx, // for each bin, call single-point cot for (uint64_t i = 0; i < batch_num; ++i) { const uint64_t this_size = - (i == batch_size - 1) ? full_size - i * batch_size : batch_size; + (i == batch_num - 1) ? full_size - i * batch_size : batch_size; const auto cot_slice = cot.Slice(i * math::Log2Ceil(this_size), (i + 1) * math::Log2Ceil(this_size)); FerretGywzOtExtRecv(ctx, cot_slice, this_size, diff --git a/yacl/crypto/primitives/ot/ferret_ote_test.cc b/yacl/crypto/primitives/ot/ferret_ote_test.cc index 72a7f58e..cfcd6b49 100644 --- a/yacl/crypto/primitives/ot/ferret_ote_test.cc +++ b/yacl/crypto/primitives/ot/ferret_ote_test.cc @@ -83,4 +83,31 @@ INSTANTIATE_TEST_SUITE_P( FerretParams{1 << 24, LpnNoiseAsm::RegularNoise}, FerretParams{1 << 25, LpnNoiseAsm::RegularNoise})); +TEST(FerretOtExtEdgeTest, Test) { + // GIVEN + const int kWorldSize = 2; + const auto assumption = LpnNoiseAsm::RegularNoise; + + auto lctxs = link::test::SetupWorld(kWorldSize); // setup network + auto lpn_param = LpnParam(10485760, 452000, 1280, assumption); + + // ot_num < minium size of base_cot + const size_t ot_num = FerretCotHelper(lpn_param, 0) - 1; + auto cot_num = FerretCotHelper(lpn_param, ot_num); // make option + auto cots_compact = MockCompactOts(cot_num); // mock cots + + // WHEN + auto sender = std::async([&] { + ASSERT_THROW( + FerretOtExtSend(lctxs[0], cots_compact.send, lpn_param, ot_num), + ::yacl::Exception); + }); + auto receiver = std::async([&] { + ASSERT_THROW( + FerretOtExtRecv(lctxs[1], cots_compact.recv, lpn_param, ot_num), + ::yacl::Exception); + }); + sender.get(); + receiver.get(); +} } // namespace yacl::crypto diff --git a/yacl/crypto/primitives/tpre/BUILD.bazel b/yacl/crypto/primitives/tpre/BUILD.bazel index 8f7e457d..260d8b52 100644 --- a/yacl/crypto/primitives/tpre/BUILD.bazel +++ b/yacl/crypto/primitives/tpre/BUILD.bazel @@ -35,7 +35,6 @@ yacl_cc_library( hdrs = ["hash.h"], deps = [ ":kdf", - "//yacl/base:dynamic_bitset", "//yacl/crypto/base/ecc:spi", "//yacl/crypto/base/hash:hash_utils", "//yacl/math/mpint", diff --git a/yacl/crypto/primitives/tpre/capsule.cc b/yacl/crypto/primitives/tpre/capsule.cc index 4fd19fcf..200ecb20 100644 --- a/yacl/crypto/primitives/tpre/capsule.cc +++ b/yacl/crypto/primitives/tpre/capsule.cc @@ -26,7 +26,6 @@ namespace yacl::crypto { std::pair> Capsule::EnCapsulate( const std::unique_ptr& ecc_group, const Keys::PublicKey& delegating_public_key) const { - MPInt zero_bn(0); MPInt order = ecc_group->GetOrder(); MPInt r; MPInt::RandomLtN(order, &r); @@ -35,25 +34,15 @@ std::pair> Capsule::EnCapsulate( EcPoint E = ecc_group->MulBase(r); EcPoint V = ecc_group->MulBase(u); - std::string E_string_join_V_sting = - std::string(ecc_group->SerializePoint(E)) + - std::string(ecc_group->SerializePoint(V)); - MPInt s = u.AddMod( - r.MulMod(CipherHash(E_string_join_V_sting, ecc_group), order), order); + MPInt s = u.AddMod(r.MulMod(CipherHash({E, V}, ecc_group), order), order); EcPoint K_point = ecc_group->Mul(delegating_public_key.y, u.AddMod(r, order)); - std::string K_string = std::string(ecc_group->SerializePoint(K_point)); - Capsule::CapsuleStruct capsule_struct = {E, V, s}; - std::vector K = KDF(K_string, 16); - - std::pair> capsule_pair; - capsule_pair.first = capsule_struct; - capsule_pair.second = K; + std::vector K = KDF(ecc_group->SerializePoint(K_point), 16); - return capsule_pair; + return {capsule_struct, K}; } // Decapsulate(skA,capsule)->(K) @@ -63,9 +52,8 @@ std::vector Capsule::DeCapsulate( const CapsuleStruct& capsule_struct) const { EcPoint K_point = ecc_group->Mul( ecc_group->Add(capsule_struct.E, capsule_struct.V), private_key.x); - std::string K_string = std::string(ecc_group->SerializePoint(K_point)); - std::vector K = KDF(K_string, 16); + std::vector K = KDF(ecc_group->SerializePoint(K_point), 16); return K; } @@ -76,39 +64,22 @@ std::pair Capsule::CheckCapsule( EcPoint tmp0 = ecc_group->MulBase(capsule_struct.s); // compute H_2(E,V) - std::string E_string_join_V_sting = - std::string(ecc_group->SerializePoint(capsule_struct.E)) + - std::string(ecc_group->SerializePoint(capsule_struct.V)); - MPInt hev = CipherHash(E_string_join_V_sting, ecc_group); + MPInt hev = CipherHash({capsule_struct.E, capsule_struct.V}, ecc_group); EcPoint e_exp_hev = ecc_group->Mul(capsule_struct.E, hev); EcPoint tmp1 = ecc_group->Add(capsule_struct.V, e_exp_hev); - std::string tmp0_string = std::string(ecc_group->SerializePoint(tmp0)); - std::string tmp1_string = std::string(ecc_group->SerializePoint(tmp1)); - - int signal; - if (tmp0_string == tmp1_string) { - signal = 1; - } else { - signal = 0; - } - - std::pair capsule_check_result = {capsule_struct, - signal}; - - return capsule_check_result; + return {capsule_struct, ecc_group->PointEqual(tmp0, tmp1)}; } -// /** -// * Each Re-encryptor generates the ciphertext fragment, i.e., cfrag -// * */ +// +// Each Re-encryptor generates the ciphertext fragment, i.e., cfrag +// Capsule::CFrag Capsule::ReEncapsulate(const std::unique_ptr& ecc_group, const Keys::KFrag& kfrag, const CapsuleStruct& capsule) const { // First checks the validity of the capsule with CheckCapsule and outputs ⊥ // if the check fails. - auto capsule_check_result = CheckCapsule(ecc_group, capsule); YACL_ENFORCE(capsule_check_result.second == 1, @@ -130,18 +101,11 @@ std::vector Capsule::DeCapsulateFrags( const std::unique_ptr& ecc_group, const Keys::PrivateKey& sk_B, const Keys::PublicKey& pk_A, const Keys::PublicKey& pk_B, const std::vector& cfrags) const { - MPInt one_bn(1); - // Compute (pk_B)^a EcPoint pk_A_mul_b = ecc_group->Mul(pk_A.y, sk_B.x); - std::string pk_A_mul_b_str = - std::string(ecc_group->SerializePoint(pk_A_mul_b)); - std::string pk_A_str = std::string(ecc_group->SerializePoint(pk_A.y)); - std::string pk_B_str = std::string(ecc_group->SerializePoint(pk_B.y)); - // 1. Compute D = H_6(pk_A, pk_B, (pk_A)^b) - MPInt D = CipherHash(pk_A_str + pk_B_str + pk_A_mul_b_str, ecc_group); + MPInt D = CipherHash({pk_A.y, pk_B.y, pk_A_mul_b}, ecc_group); // 2. Compute s_{x,i} and lambda_{i,S} // 2.1 Compute s_{x,i} = H_5(id_i, D) @@ -191,20 +155,14 @@ std::vector Capsule::DeCapsulateFrags( // 4. Compute d = H_3(X_A,pk_B,(X_A)^b) std::string X_A_str = std::string(ecc_group->SerializePoint(cfrags[0].X_A)); EcPoint X_A_mul_b = ecc_group->Mul(cfrags[0].X_A, sk_B.x); - std::string X_A_mul_b_str = std::string(ecc_group->SerializePoint(X_A_mul_b)); - MPInt d = CipherHash(X_A_str + pk_B_str + X_A_mul_b_str, ecc_group); + MPInt d = CipherHash({cfrags[0].X_A, pk_B.y, X_A_mul_b}, ecc_group); // 5. Compute DEK, i.e., K=KDF((E'· V')^d) - EcPoint E_prime_add_V_prime = ecc_group->Add(E_prime, V_prime); EcPoint E_prime_add_V_prime_mul_d = ecc_group->Mul(E_prime_add_V_prime, d); - std::string K_string = - std::string(ecc_group->SerializePoint(E_prime_add_V_prime_mul_d)); - - std::vector K = KDF(K_string, 16); - - return K; + return KDF(ecc_group->SerializePoint(E_prime_add_V_prime_mul_d), 16); } + } // namespace yacl::crypto diff --git a/yacl/crypto/primitives/tpre/capsule.h b/yacl/crypto/primitives/tpre/capsule.h index bd96e035..cf0e838d 100644 --- a/yacl/crypto/primitives/tpre/capsule.h +++ b/yacl/crypto/primitives/tpre/capsule.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef YACL_CRYPTO_PRIMITIVES_TPRE_CAPSULE_H_ -#define YACL_CRYPTO_PRIMITIVES_TPRE_CAPSULE_H_ +#pragma once + #include #include #include @@ -36,18 +36,18 @@ namespace yacl::crypto { */ class Capsule { public: - Capsule() {} - ~Capsule() {} + Capsule() = default; + ~Capsule() = default; - /// @brief Capsule for encapsulating data keys, Capsule struct includes 2 - /// elliptic curve point and a big number + // @brief Capsule for encapsulating data keys, Capsule struct includes 2 + // elliptic curve point and a big number struct CapsuleStruct { EcPoint E; // E = g^r, g is the generator of elliptic group EcPoint V; // V = g^u, g is the generator of elliptic group MPInt s; // s = u + r · H(E, V) }; - /// @brief CFrag is the fragment of Capsule after re-encapsulating + // @brief CFrag is the fragment of Capsule after re-encapsulating struct CFrag { EcPoint E_1; // E_1 = E^rk EcPoint V_1; // V_1 = V^rk @@ -55,52 +55,52 @@ class Capsule { EcPoint X_A; // X_A = g^x_A }; - /// @brief EnCapsulate algorithm, generate and capsulate the random data - /// encryption key - /// @param ecc_group - /// @param delegating_public_key - /// @return capsule and data ecnryption key + // @brief EnCapsulate algorithm, generate and capsulate the random data + // encryption key + // @param ecc_group + // @param delegating_public_key + // @return capsule and data ecnryption key std::pair> EnCapsulate( const std::unique_ptr& ecc_group, const Keys::PublicKey& delegating_public_key) const; - /// @brief DeCapsulate algorithm, to obtain the data encryption key - /// @param private_key - /// @param capsule_struct - /// @return data encryption key + // @brief DeCapsulate algorithm, to obtain the data encryption key + // @param private_key + // @param capsule_struct + // @return data encryption key std::vector DeCapsulate(const std::unique_ptr& ecc_group, const Keys::PrivateKey& private_key, const CapsuleStruct& capsule_struct) const; - /// @brief Capsule check algorithm - /// @param ecc_group - /// @param capsule_struct - /// @return 0 (check fail) or 1 (check success) + // @brief Capsule check algorithm + // @param ecc_group + // @param capsule_struct + // @return 0 (check fail) or 1 (check success) std::pair CheckCapsule( const std::unique_ptr& ecc_group, const CapsuleStruct& capsule_struct) const; - /// @brief Re-encapsulate capsule - /// @param ecc_group - /// @param kfrag, re-encryption key fragment - /// @param capsule - /// @return Re-encapsulated capsule + // @brief Re-encapsulate capsule + // @param ecc_group + // @param kfrag, re-encryption key fragment + // @param capsule + // @return Re-encapsulated capsule CFrag ReEncapsulate(const std::unique_ptr& ecc_group, const Keys::KFrag& kfrag, const CapsuleStruct& capsule) const; - /// @brief Restore the re-encapsulated capsule set to data encryption key - /// @param ecc_group - /// @param sk_B, secret key of Bob - /// @param pk_A, public key of Alice - /// @param pk_B, public key of Bob - /// @param cfrags, re-encapsulated capsule set - /// @return Data encryption key + // @brief Restore the re-encapsulated capsule set to data encryption key + // @param ecc_group + // @param sk_B, secret key of Bob + // @param pk_A, public key of Alice + // @param pk_B, public key of Bob + // @param cfrags, re-encapsulated capsule set + // @return Data encryption key std::vector DeCapsulateFrags( const std::unique_ptr& ecc_group, const Keys::PrivateKey& sk_B, const Keys::PublicKey& pk_A, const Keys::PublicKey& pk_B, const std::vector& cfrags) const; }; + } // namespace yacl::crypto -#endif // YACL_CRYPTO_PRIMITIVES_TPRE_CAPSULE_H_ diff --git a/yacl/crypto/primitives/tpre/hash.cc b/yacl/crypto/primitives/tpre/hash.cc index 7fd8e01d..6b7eb79e 100644 --- a/yacl/crypto/primitives/tpre/hash.cc +++ b/yacl/crypto/primitives/tpre/hash.cc @@ -21,7 +21,6 @@ #include #include -#include "yacl/base/dynamic_bitset.h" #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/crypto/primitives/tpre/kdf.h" @@ -30,18 +29,34 @@ namespace yacl::crypto { // where n is the degree of EC Group, and x is input MPInt CipherHash(ByteContainerView input, const std::unique_ptr& ecc_group) { - std::array hash_value_0 = Sm3(input); - std::array hash_value_1 = Sm3(hash_value_0); + auto hash_value_0 = Sm3(input); + auto hash_value_1 = Sm3(hash_value_0); - dynamic_bitset binary; - binary.append(hash_value_0.begin(), hash_value_0.end()); - binary.append(hash_value_1.begin(), hash_value_1.end()); - MPInt hash_bn(binary.to_string(), 2); + std::vector buf; + buf.insert(buf.end(), hash_value_0.begin(), hash_value_0.end()); + buf.insert(buf.end(), hash_value_1.begin(), hash_value_1.end()); + + MPInt hash_bn; + hash_bn.FromMagBytes(buf); - MPInt one_bn(1); // h_x = 1 + Bignum(sm3(x)||sm3(sm3(x))) mod n-1 - MPInt h_x = one_bn.AddMod(hash_bn, ecc_group->GetOrder() - one_bn); + MPInt h_x = hash_bn.AddMod(1_mp, ecc_group->GetOrder() - 1_mp); return h_x; } + +MPInt CipherHash(std::initializer_list inputs, + const std::unique_ptr& ecc_group) { + auto len = ecc_group->GetSerializeLength(); + Buffer buf(len * inputs.size()); + + uint8_t index = 0; + for (const auto& p : inputs) { + ecc_group->SerializePoint(p, buf.data() + index * len, len); + index++; + } + + return CipherHash(buf, ecc_group); +} + } // namespace yacl::crypto diff --git a/yacl/crypto/primitives/tpre/hash.h b/yacl/crypto/primitives/tpre/hash.h index 1e1cd845..6c25f9e9 100644 --- a/yacl/crypto/primitives/tpre/hash.h +++ b/yacl/crypto/primitives/tpre/hash.h @@ -12,22 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef YACL_CRYPTO_PRIMITIVES_TPRE_HASH_H_ -#define YACL_CRYPTO_PRIMITIVES_TPRE_HASH_H_ +#pragma once #include -#include "yacl/crypto/base/ecc/ec_point.h" #include "yacl/crypto/base/ecc/ecc_spi.h" namespace yacl::crypto { -/// @brief Cryptographic hash function, h_x = 1 + Bignum(sm3(x)||sm3(sm3(x))), -/// where n is the degree of EC Group, and x is input mod n-1 -/// @param input -/// @param curve_id, elliptic curve type -/// @return hash value +// @brief Cryptographic hash function, h_x = 1 + Bignum(sm3(x)||sm3(sm3(x))), +// where n is the degree of EC Group, and x is input mod n-1 +// @param input +// @param curve_id, elliptic curve type +// @return hash value MPInt CipherHash(ByteContainerView input, const std::unique_ptr& ecc_group); + +MPInt CipherHash(std::initializer_list input, + const std::unique_ptr& ecc_group); + } // namespace yacl::crypto -#endif // YACL_CRYPTO_PRIMITIVES_TPRE_HASH_H_ diff --git a/yacl/crypto/primitives/tpre/hash_test.cc b/yacl/crypto/primitives/tpre/hash_test.cc index 7a6578d7..b37eca04 100644 --- a/yacl/crypto/primitives/tpre/hash_test.cc +++ b/yacl/crypto/primitives/tpre/hash_test.cc @@ -22,15 +22,21 @@ namespace yacl::crypto::test { TEST(HashTest, Test1) { - MPInt zero(0); std::unique_ptr ecc_group = EcGroupFactory::Instance().Create("sm2"); auto hash_value = CipherHash("tpre", ecc_group); - - std::cout << "hash_value = " << hash_value.ToHexString() << std::endl; - EXPECT_TRUE(hash_value > zero); + EXPECT_TRUE(hash_value > 0_mp); EXPECT_EQ(hash_value.ToHexString(), "3532674C20DA7E34FE48093538D7E4167E3C39472B19EBACE593579EA6073329"); + + auto hash_value2 = CipherHash({ecc_group->GetGenerator()}, ecc_group); + EXPECT_EQ(hash_value2.ToHexString(), + "2FE6D05F44F7387077FE1ACECC457BBE3D208C513CAA94FDBA3B58C691D84F21"); + + auto hash_value3 = CipherHash( + {ecc_group->GetGenerator(), ecc_group->MulBase(2_mp)}, ecc_group); + EXPECT_EQ(hash_value3.ToHexString(), + "2A37F6D7231C9CC72D8B8FBEF9A859992B9BDADAC1BDB9E73D881967EB145854"); } } // namespace yacl::crypto::test diff --git a/yacl/crypto/primitives/tpre/kdf.h b/yacl/crypto/primitives/tpre/kdf.h index f31f4b73..80806988 100644 --- a/yacl/crypto/primitives/tpre/kdf.h +++ b/yacl/crypto/primitives/tpre/kdf.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef YACL_CRYPTO_PRIMITIVES_TPRE_KDF_H_ -#define YACL_CRYPTO_PRIMITIVES_TPRE_KDF_H_ +#pragma once + #include #include @@ -21,15 +21,13 @@ namespace yacl::crypto { -/// @brief The KDF structure is derived from the official document of SM2, i.e., -/// "Public Key Cryptographic Algorithm SM2 Based on Elliptic Curves", -/// reference: -/// http://www.sca.gov.cn/sca/xwdt/2010-12/17/1002386/files/b791a9f908bb4803875ab6aeeb7b4e03.pdf -/// @param Z, a random value -/// @param key_len, the key length -/// @return key +// @brief The KDF structure is derived from the official document of SM2, i.e., +// "Public Key Cryptographic Algorithm SM2 Based on Elliptic Curves", +// reference: +// http://www.sca.gov.cn/sca/xwdt/2010-12/17/1002386/files/b791a9f908bb4803875ab6aeeb7b4e03.pdf +// @param Z, a random value +// @param key_len, the key length +// @return key std::vector KDF(ByteContainerView Z, size_t key_len); } // namespace yacl::crypto - -#endif // YACL_CRYPTO_PRIMITIVES_TPRE_KDF_H_ diff --git a/yacl/crypto/primitives/tpre/kdf_test.cc b/yacl/crypto/primitives/tpre/kdf_test.cc index 8a7f8e5a..3ae38b17 100644 --- a/yacl/crypto/primitives/tpre/kdf_test.cc +++ b/yacl/crypto/primitives/tpre/kdf_test.cc @@ -17,16 +17,13 @@ #include #include -#include "absl/strings/escaping.h" #include "gtest/gtest.h" namespace yacl::crypto::test { TEST(KDFTest, Test1) { std::vector key = KDF("key_str", 16); - std::string key_str = absl::BytesToHexString( - absl::string_view((const char*)key.data(), key.size())); - - EXPECT_EQ(key_str, "93a42c6b4c02ab6956f0095787c67e5e"); + EXPECT_EQ(fmt::format("{:02x}", fmt::join(key, "")), + "93a42c6b4c02ab6956f0095787c67e5e"); } } // namespace yacl::crypto::test diff --git a/yacl/crypto/primitives/tpre/keys.cc b/yacl/crypto/primitives/tpre/keys.cc index fe3e2936..47d88b2f 100644 --- a/yacl/crypto/primitives/tpre/keys.cc +++ b/yacl/crypto/primitives/tpre/keys.cc @@ -20,36 +20,28 @@ namespace yacl::crypto { std::pair Keys::GenerateKeyPair( const std::unique_ptr& ecc_group) const { - EcPoint g = ecc_group->GetGenerator(); - // sample random from ecc group - MPInt max = ecc_group->GetOrder(); + MPInt ecc_group_order = ecc_group->GetOrder(); MPInt x; - MPInt::RandomLtN(max, &x); + MPInt::RandomLtN(ecc_group_order, &x); // compute y = g^x EcPoint y = ecc_group->MulBase(x); // Assign private key Keys::PrivateKey private_key = {x}; // Assign public key - Keys::PublicKey public_key = {g, y}; + Keys::PublicKey public_key = {ecc_group->GetGenerator(), y}; - std::pair key_pair; - key_pair.first = public_key; - key_pair.second = private_key; - return key_pair; + return {public_key, private_key}; } -// // Generates re-ecnryption key +// Generates re-ecnryption key std::vector Keys::GenerateReKey( const std::unique_ptr& ecc_group, const PrivateKey& sk_A, const PublicKey& pk_A, const PublicKey& pk_B, int N, int t) const { - MPInt zero_bn(0); - MPInt one_bn(1); MPInt ecc_group_order = ecc_group->GetOrder(); // 1. Select x_A randomly and calculation X_ A=g^{x_A} - MPInt max = ecc_group_order; MPInt x_A; MPInt::RandomLtN(ecc_group_order, &x_A); @@ -58,13 +50,7 @@ std::vector Keys::GenerateReKey( // 2. Compute d = H_3(X_A, pk_B, (pk_B)^{X_A}), where d is the result of a // non-interactive Diffie-Hellman key exchange between B's keypair and the // ephemeral key pair (x_A, X_A). - - std::string pk_B_str = std::string(ecc_group->SerializePoint(pk_B.y)); - std::string pk_B_mul_x_A_str = - std::string(ecc_group->SerializePoint(ecc_group->Mul(pk_B.y, x_A))); - std::string X_A_str = std::string(ecc_group->SerializePoint(X_A)); - - MPInt d = CipherHash(X_A_str + pk_B_str + pk_B_mul_x_A_str, ecc_group); + MPInt d = CipherHash({X_A, pk_B.y, ecc_group->Mul(pk_B.y, x_A)}, ecc_group); // 3. Generate random polynomial coefficients {f_1,...,f_{t-1}} and calculate // coefficients f_ 0 @@ -78,18 +64,15 @@ std::vector Keys::GenerateReKey( for (int i = 1; i <= t - 1; i++) { // Here, t-1 coefficients f_1,...,f_{t-1} are randomly generated. MPInt f_i; - MPInt::RandomLtN(max, &f_i); + MPInt::RandomLtN(ecc_group_order, &f_i); coefficients.push_back(f_i); } // 4. Generate a polynomial via coefficient // 5. Compute D=H_6(pk_A, pk_B, pk^{a}_{B}), where a is the secret key of A - std::string pk_A_str = std::string(ecc_group->SerializePoint(pk_A.y)); - std::string pk_B_mul_a_str = - std::string(ecc_group->SerializePoint(ecc_group->Mul(pk_B.y, sk_A.x))); - - MPInt D = CipherHash(pk_A_str + pk_B_str + pk_B_mul_a_str, ecc_group); + MPInt D = + CipherHash({pk_A.y, pk_B.y, ecc_group->Mul(pk_B.y, sk_A.x)}, ecc_group); // 6. Compute KFrags @@ -104,16 +87,16 @@ std::vector Keys::GenerateReKey( std::vector kfrags; MPInt r_tmp_0; - MPInt::RandomLtN(max, &r_tmp_0); + MPInt::RandomLtN(ecc_group_order, &r_tmp_0); EcPoint U = ecc_group->MulBase(r_tmp_0); // Cycle to generate each element of kfrags for (int i = 0; i <= N - 1; i++) { MPInt r_tmp_1; - MPInt::RandomLtN(max, &r_tmp_1); + MPInt::RandomLtN(ecc_group_order, &r_tmp_1); y.push_back(r_tmp_1); MPInt r_tmp_2; - MPInt::RandomLtN(max, &r_tmp_2); + MPInt::RandomLtN(ecc_group_order, &r_tmp_2); id.push_back(r_tmp_2); s_x.push_back(CipherHash(id[i].ToString() + D.ToString(), ecc_group)); @@ -122,8 +105,8 @@ std::vector Keys::GenerateReKey( // Compute polynomial to obtain rk[i] MPInt rk_tmp = coefficients[0]; - MPInt s_x_exp_j = zero_bn; - MPInt coeff_mul_s_x_exp_j = zero_bn; + MPInt s_x_exp_j = 0_mp; + MPInt coeff_mul_s_x_exp_j = 0_mp; for (int j = 1; j <= t - 1; j++) { s_x_exp_j = s_x[i]; @@ -156,4 +139,5 @@ std::vector Keys::GenerateReKey( return kfrags; } + } // namespace yacl::crypto diff --git a/yacl/crypto/primitives/tpre/keys.h b/yacl/crypto/primitives/tpre/keys.h index b23871bb..f48365bd 100644 --- a/yacl/crypto/primitives/tpre/keys.h +++ b/yacl/crypto/primitives/tpre/keys.h @@ -12,8 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef YACL_CRYPTO_PRIMITIVES_TPRE_KEYS_H_ -#define YACL_CRYPTO_PRIMITIVES_TPRE_KEYS_H_ +#pragma once #include #include @@ -44,8 +43,8 @@ namespace yacl::crypto { */ class Keys { public: - Keys() {} - ~Keys() {} + Keys() = default; + ~Keys() = default; /// @brief public key struct struct PublicKey { @@ -88,6 +87,5 @@ class Keys { const PublicKey& pk_A, const PublicKey& pk_B, int N, int t) const; }; -} // namespace yacl::crypto -#endif // YACL_CRYPTO_PRIMITIVES_TPRE_KEYS_H_ +} // namespace yacl::crypto diff --git a/yacl/crypto/primitives/tpre/keys_test.cc b/yacl/crypto/primitives/tpre/keys_test.cc index 3c3e8373..0818ae45 100644 --- a/yacl/crypto/primitives/tpre/keys_test.cc +++ b/yacl/crypto/primitives/tpre/keys_test.cc @@ -24,7 +24,6 @@ namespace yacl::crypto::test { TEST(KeyTest, Test1) { - MPInt zero(0); std::unique_ptr ecc_group = EcGroupFactory::Instance().Create("sm2"); Keys keys; @@ -62,7 +61,7 @@ TEST(KeyTest, Test1) { key_pair_bob.first, 5, 4); for (int i = 0; i < 5; i++) { - EXPECT_TRUE(kfrags[i].id > zero); + EXPECT_TRUE(kfrags[i].id > 0_mp); } } } // namespace yacl::crypto::test diff --git a/yacl/crypto/primitives/tpre/tpre.h b/yacl/crypto/primitives/tpre/tpre.h index ba1cb15b..7a98e5d4 100644 --- a/yacl/crypto/primitives/tpre/tpre.h +++ b/yacl/crypto/primitives/tpre/tpre.h @@ -12,8 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef YACL_CRYPTO_PRIMITIVES_TPRE_TPRE_H_ -#define YACL_CRYPTO_PRIMITIVES_TPRE_TPRE_H_ +#pragma once #include #include @@ -114,6 +113,5 @@ class TPRE { const std::pair, std::vector>& C_prime_set) const; }; -} // namespace yacl::crypto -#endif // YACL_CRYPTO_PRIMITIVES_TPRE_TPRE_H_ +} // namespace yacl::crypto diff --git a/yacl/link/BUILD.bazel b/yacl/link/BUILD.bazel index 0bd91671..774394f6 100644 --- a/yacl/link/BUILD.bazel +++ b/yacl/link/BUILD.bazel @@ -13,6 +13,8 @@ # limitations under the License. load("//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test") +load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(default_visibility = ["//visibility:public"]) @@ -40,6 +42,9 @@ yacl_cc_library( yacl_cc_library( name = "ssl_options", hdrs = ["ssl_options.h"], + deps = [ + ":link_cc_proto", + ], ) yacl_cc_library( @@ -47,6 +52,7 @@ yacl_cc_library( srcs = ["context.cc"], hdrs = ["context.h"], deps = [ + ":link_cc_proto", ":ssl_options", ":trace", "//yacl/base:byte_container_view", @@ -115,3 +121,15 @@ yacl_cc_library( "@com_github_fmtlib_fmt//:fmtlib", ], ) + +proto_library( + name = "link_proto", + srcs = ["link.proto"], +) + +cc_proto_library( + name = "link_cc_proto", + deps = [ + ":link_proto", + ], +) diff --git a/yacl/link/context.cc b/yacl/link/context.cc index 9215b38c..2edb1227 100644 --- a/yacl/link/context.cc +++ b/yacl/link/context.cc @@ -89,6 +89,12 @@ Context::Context(ContextDesc desc, size_t rank, stats_ = std::make_shared(); } +Context::Context(const ContextDescProto& desc_pb, size_t rank, + std::vector> channels, + std::shared_ptr msg_loop, + bool is_sub_world) + : Context(ContextDesc(), rank, channels, msg_loop, is_sub_world) {} + std::string Context::Id() const { return desc_.id; } size_t Context::WorldSize() const { return desc_.parties.size(); } diff --git a/yacl/link/context.h b/yacl/link/context.h index 6db5f999..4223863e 100644 --- a/yacl/link/context.h +++ b/yacl/link/context.h @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -27,10 +28,24 @@ #include "yacl/link/transport/channel.h" #include "yacl/utils/hash.h" +#include "yacl/link/link.pb.h" + namespace yacl::link { constexpr size_t kAllRank = std::numeric_limits::max(); + struct ContextDesc { + static constexpr char kDefaultId[] = "root"; + static constexpr uint32_t kDefaultConnectRetryTimes = 10; + static constexpr uint32_t kDefaultConnectRetryIntervalMs = 1000; // 1 second. + static constexpr uint64_t kDefaultRecvTimeoutMs = 30 * 1000; // 30s + static constexpr uint32_t kDefaultHttpMaxPayloadSize = + 1024 * 1024; // 1M Bytes + static constexpr uint32_t kDefaultHttpTimeoutMs = 20 * 1000; // 20 seconds. + static constexpr uint32_t kDefaultThrottleWindowSize = 10; + static constexpr char kDefaultBrpcChannelProtocol[] = "baidu_std"; + static constexpr char kDefaultLinkType[] = "normal"; + struct Party { std::string id; std::string host; @@ -38,19 +53,27 @@ struct ContextDesc { bool operator==(const Party& p) const { return (id == p.id) && (host == p.host); } + + Party() = default; + + Party(const PartyProto& pb) : id(pb.id()), host(pb.host()) {} + + Party(const std::string& id_, const std::string& host_) + : id(id_), host(host_) {} }; // the UUID of this communication. - std::string id = "root"; + std::string id = kDefaultId; // party description, describes the world. std::vector parties; // connect to mesh retry time. - uint32_t connect_retry_times = 10; + uint32_t connect_retry_times = kDefaultConnectRetryTimes; // connect to mesh retry interval. - uint32_t connect_retry_interval_ms = 1000; // 1 second. + uint32_t connect_retry_interval_ms = + kDefaultConnectRetryIntervalMs; // 1 second. // recv timeout in milliseconds. // @@ -69,24 +92,24 @@ struct ContextDesc { // // so for long time work(that one party may wait for the others for very long // time), this value should be changed accordingly. - uint64_t recv_timeout_ms = 30 * 1000; // 30s + uint64_t recv_timeout_ms = kDefaultRecvTimeoutMs; // 30s // http max payload size, if a single http request size is greater than this // limit, it will be unpacked into small chunks then reassembled. // // This field does affect performance. Please choose wisely. - uint32_t http_max_payload_size = 1024 * 1024; // 1M Bytes + uint32_t http_max_payload_size = kDefaultHttpMaxPayloadSize; // 1M Bytes // a single http request timetout. - uint32_t http_timeout_ms = 20 * 1000; // 20 seconds. + uint32_t http_timeout_ms = kDefaultHttpTimeoutMs; // 20 seconds. // throttle window size for channel. if there are more than limited size // messages are flying, `SendAsync` will block until messages are processed or // throw exception after wait for `recv_timeout_ms` - uint32_t throttle_window_size = 10; + uint32_t throttle_window_size = kDefaultThrottleWindowSize; // BRPC client channel protocol. - std::string brpc_channel_protocol = "baidu_std"; + std::string brpc_channel_protocol = kDefaultBrpcChannelProtocol; // BRPC client channel connection type. std::string brpc_channel_connection_type = ""; @@ -107,11 +130,44 @@ struct ContextDesc { bool exit_if_async_error = true; // "blackbox" or "normal", default: "normal" - std::string link_type = "normal"; + std::string link_type = kDefaultLinkType; bool operator==(const ContextDesc& other) const { return (id == other.id) && (parties == other.parties); } + + ContextDesc() = default; + + ContextDesc(const ContextDescProto& pb) + : id(pb.id().size() ? pb.id() : kDefaultId), + connect_retry_times(pb.connect_retry_times() + ? pb.connect_retry_times() + : kDefaultConnectRetryTimes), + connect_retry_interval_ms(pb.connect_retry_interval_ms() + ? pb.connect_retry_interval_ms() + : kDefaultConnectRetryIntervalMs), + recv_timeout_ms(pb.recv_timeout_ms() ? pb.recv_timeout_ms() + : kDefaultRecvTimeoutMs), + http_max_payload_size(pb.http_max_payload_size() + ? pb.http_max_payload_size() + : kDefaultHttpMaxPayloadSize), + http_timeout_ms(pb.http_timeout_ms() ? pb.http_timeout_ms() + : kDefaultHttpTimeoutMs), + throttle_window_size(pb.throttle_window_size() + ? pb.throttle_window_size() + : kDefaultThrottleWindowSize), + brpc_channel_protocol(pb.brpc_channel_protocol().size() + ? pb.brpc_channel_protocol() + : kDefaultBrpcChannelProtocol), + brpc_channel_connection_type(pb.brpc_channel_connection_type()), + enable_ssl(pb.enable_ssl()), + client_ssl_opts(pb.client_ssl_opts()), + server_ssl_opts(pb.server_ssl_opts()), + link_type(kDefaultLinkType) { + for (const auto& party_pb : pb.parties()) { + parties.emplace_back(party_pb); + } + } }; struct ContextDescHasher { @@ -159,6 +215,11 @@ class Context { std::shared_ptr msg_loop, bool is_sub_world = false); + Context(const ContextDescProto& desc_pb, size_t rank, + std::vector> channels, + std::shared_ptr msg_loop, + bool is_sub_world = false); + std::string Id() const; size_t WorldSize() const; diff --git a/yacl/link/context_test.cc b/yacl/link/context_test.cc index 6f092d9d..0dc6f953 100644 --- a/yacl/link/context_test.cc +++ b/yacl/link/context_test.cc @@ -19,6 +19,7 @@ #include "fmt/format.h" #include "gmock/gmock.h" +#include "google/protobuf/util/json_util.h" #include "gtest/gtest.h" #include "yacl/base/byte_container_view.h" @@ -26,6 +27,8 @@ #include "yacl/link/factory.h" #include "yacl/link/transport/channel_mem.h" +#include "yacl/link/link.pb.h" + namespace yacl::link::test { class MockChannel : public transport::IChannel { @@ -278,4 +281,66 @@ TEST(EnvInfo, get_party_node_info) { EXPECT_EQ(parties.size(), 2); } +TEST(ContextDesc, construct_from_pb) { + ContextDescProto pb; + std::string json = R"json( + { + "parties": [ + { + "id": "alice", + "host": "1.2.3.4:1000" + }, + { + "id": "bob", + "host": "1.2.3.5:2000" + } + ], + "connect_retry_times": 15, + "recv_timeout_ms": 20000, + "brpc_channel_protocol": "thrift", + "brpc_channel_connection_type":"single", + "enable_ssl": true, + "client_ssl_opts": { + "certificate_path": "certificate_path/alice", + "private_key_path": "private_key_path/alice", + "verify_depth": 1, + "ca_file_path": "ca_file_path/alice" + }, + "server_ssl_opts": { + "certificate_path": "certificate_path/bob", + "private_key_path": "private_key_path/bob", + "verify_depth": 1, + "ca_file_path": "ca_file_path/bob" + } + })json"; + + EXPECT_TRUE(google::protobuf::util::JsonStringToMessage(json, &pb).ok()); + + ContextDesc desc(pb); + + EXPECT_EQ(desc.id, ContextDesc::kDefaultId); + EXPECT_EQ(desc.parties.size(), 2); + EXPECT_EQ(desc.parties[0].id, "alice"); + EXPECT_EQ(desc.parties[1].host, "1.2.3.5:2000"); + EXPECT_EQ(desc.connect_retry_times, 15); + EXPECT_EQ(desc.connect_retry_interval_ms, + ContextDesc::kDefaultConnectRetryIntervalMs); + EXPECT_EQ(desc.recv_timeout_ms, 20000); + EXPECT_EQ(desc.http_max_payload_size, + ContextDesc::kDefaultHttpMaxPayloadSize); + EXPECT_EQ(desc.http_timeout_ms, ContextDesc::kDefaultHttpTimeoutMs); + EXPECT_EQ(desc.throttle_window_size, ContextDesc::kDefaultThrottleWindowSize); + EXPECT_EQ(desc.brpc_channel_protocol, "thrift"); + EXPECT_EQ(desc.brpc_channel_connection_type, "single"); + EXPECT_EQ(desc.enable_ssl, true); + EXPECT_EQ(desc.client_ssl_opts.cert.certificate_path, + "certificate_path/alice"); + EXPECT_EQ(desc.client_ssl_opts.cert.private_key_path, + "private_key_path/alice"); + EXPECT_EQ(desc.server_ssl_opts.verify.verify_depth, 1); + EXPECT_EQ(desc.server_ssl_opts.verify.ca_file_path, "ca_file_path/bob"); + EXPECT_EQ(desc.exit_if_async_error, true); + EXPECT_EQ(desc.link_type, ContextDesc::kDefaultLinkType); +} + } // namespace yacl::link::test diff --git a/yacl/link/factory_brpc_blackbox.cc b/yacl/link/factory_brpc_blackbox.cc index 9af79da4..20e13a02 100644 --- a/yacl/link/factory_brpc_blackbox.cc +++ b/yacl/link/factory_brpc_blackbox.cc @@ -65,7 +65,7 @@ void FactoryBrpcBlackBox::GetPartyNodeInfoFromEnv( kSelfPartyKey, self_party_id, kNodeInfoPrefix); self_rank = std::distance(party_info.begin(), iter); for (auto const& [party_id, node_id] : party_info) { - parties.emplace_back(ContextDesc::Party{.id = party_id, .host = node_id}); + parties.emplace_back(ContextDesc::Party(party_id, node_id)); } } diff --git a/yacl/link/factory_test.cc b/yacl/link/factory_test.cc index 9aa18251..c91a22c5 100644 --- a/yacl/link/factory_test.cc +++ b/yacl/link/factory_test.cc @@ -33,10 +33,8 @@ class FactoryTest : public ::testing::Test { contexts_.resize(2); ContextDesc desc; desc.id = fmt::format("world_{}", desc_count++); - desc.parties.push_back( - ContextDesc::Party{.id = "alice", .host = "127.0.0.1:63927"}); - desc.parties.push_back( - ContextDesc::Party{.id = "bob", .host = "127.0.0.1:63921"}); + desc.parties.push_back(ContextDesc::Party("alice", "127.0.0.1:63927")); + desc.parties.push_back(ContextDesc::Party("bob", "127.0.0.1:63921")); auto create_brpc = [&](int self_rank) { contexts_[self_rank] = M().CreateContext(desc, self_rank); diff --git a/yacl/link/link.proto b/yacl/link/link.proto new file mode 100644 index 00000000..1f111ad7 --- /dev/null +++ b/yacl/link/link.proto @@ -0,0 +1,106 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +syntax = "proto3"; + +package yacl.link; + +message PartyProto { + string id = 1; + + string host = 2; +} + +// SSL options. +message SSLOptionsProto { + // Certificate file path + string certificate_path = 1; + + // Private key file path + string private_key_path = 2; + + // Set the maximum depth of the certificate chain for verification + // If 0, turn off the verification + int32 verify_depth = 3; + // Set the trusted CA file to verify the peer's certificate + // If empty, use the system default CA files + string ca_file_path = 4; +} + +// Configuration for link config. +message ContextDescProto { + // the UUID of this communication. + // optional + string id = 1; + + // party description, describes the world. + repeated PartyProto parties = 2; + + // connect to mesh retry time. + uint32 connect_retry_times = 3; + + // connect to mesh retry interval. + uint32 connect_retry_interval_ms = 4; + + // recv timeout in milliseconds. + // + // 'recv time' is the max time that a party will wait for a given event. + // for example: + // + // begin recv end recv + // |--------|-------recv-time----------|------------------| alice's timeline + // + // begin send end send + // |-----busy-work-------------|-------------|------------| bob's timeline + // + // in above case, when alice begins recv for a specific event, bob is still + // busy doing its job, when alice's wait time exceed wait_timeout_ms, it raise + // exception, although bob now is starting to send data. + // + // so for long time work(that one party may wait for the others for very long + // time), this value should be changed accordingly. + uint64 recv_timeout_ms = 5; + + // http max payload size, if a single http request size is greater than this + // limit, it will be unpacked into small chunks then reassembled. + // + // This field does affect performance. Please choose wisely. + uint32 http_max_payload_size = 6; + + // a single http request timetout. + uint32 http_timeout_ms = 7; + + // throttle window size for channel. if there are more than limited size + // messages are flying, `SendAsync` will block until messages are processed or + // throw exception after wait for `recv_timeout_ms` + uint32 throttle_window_size = 8; + + // BRPC client channel protocol. + string brpc_channel_protocol = 9; + + // BRPC client channel connection type. + string brpc_channel_connection_type = 10; + + // ssl options for link channel. + bool enable_ssl = 11; + + // ssl configs for channel + // this config is ignored if enable_ssl == false; + SSLOptionsProto client_ssl_opts = 12; + + // ssl configs for service + // this config is ignored if enable_ssl == false; + SSLOptionsProto server_ssl_opts = 13; +} diff --git a/yacl/link/ssl_options.h b/yacl/link/ssl_options.h index 3d0a4e1c..38083b42 100644 --- a/yacl/link/ssl_options.h +++ b/yacl/link/ssl_options.h @@ -16,6 +16,8 @@ #include +#include "yacl/link/link.pb.h" + namespace yacl::link { struct CertInfo { @@ -43,6 +45,16 @@ struct SSLOptions { // Options used to verify the peer's certificate VerifyOptions verify; + + SSLOptions() = default; + + SSLOptions(const SSLOptionsProto& pb) { + cert.certificate_path = pb.certificate_path(); + cert.private_key_path = pb.private_key_path(); + + verify.verify_depth = pb.verify_depth(); + verify.ca_file_path = pb.ca_file_path(); + } }; } // namespace yacl::link diff --git a/yacl/link/transport/BUILD.bazel b/yacl/link/transport/BUILD.bazel index fa39c59c..84b21718 100644 --- a/yacl/link/transport/BUILD.bazel +++ b/yacl/link/transport/BUILD.bazel @@ -47,9 +47,9 @@ cc_proto_library( ) yacl_cc_library( - name = "channel_chunked_base", - srcs = ["channel_chunked_base.cc"], - hdrs = ["channel_chunked_base.h"], + name = "interconnection_base", + srcs = ["interconnection_base.cc"], + hdrs = ["interconnection_base.h"], deps = [ ":channel", ":ic_transport_proto", @@ -63,7 +63,7 @@ yacl_cc_library( srcs = ["channel_brpc.cc"], hdrs = ["channel_brpc.h"], deps = [ - ":channel_chunked_base", + ":interconnection_base", ], ) @@ -72,7 +72,7 @@ yacl_cc_library( srcs = ["channel_brpc_blackbox.cc"], hdrs = ["channel_brpc_blackbox.h"], deps = [ - ":channel_chunked_base", + ":interconnection_base", "//yacl/link/transport/blackbox_interconnect:blackbox_service_errorcode", "//yacl/link/transport/blackbox_interconnect:blackbox_service_proto", ], diff --git a/yacl/link/transport/channel.cc b/yacl/link/transport/channel.cc index 214bd95e..96d957bb 100644 --- a/yacl/link/transport/channel.cc +++ b/yacl/link/transport/channel.cc @@ -261,6 +261,73 @@ void ChannelBase::OnNormalMessage(const std::string& key, T&& v) { msg_db_cond_.notify_all(); } +class ChunkedMessage { + public: + explicit ChunkedMessage(int64_t message_length) : message_(message_length) {} + + void AddChunk(int64_t offset, ByteContainerView data) { + std::unique_lock lock(mutex_); + if (received_.emplace(offset).second) { + std::memcpy(message_.data() + offset, data.data(), + data.size()); + bytes_written_ += data.size(); + } + } + + bool IsFullyFilled() { + std::unique_lock lock(mutex_); + return bytes_written_ == message_.size(); + } + + Buffer&& Reassemble() { + std::unique_lock lock(mutex_); + return std::move(message_); + } + + protected: + bthread::Mutex mutex_; + std::set received_; + // chunk index to value. + int64_t bytes_written_{0}; + Buffer message_; +}; + +void ChannelBase::OnChunkedMessage(const std::string& key, + ByteContainerView value, size_t offset, + size_t total_length) { + if (offset + value.size() > total_length) { + YACL_THROW_LOGIC_ERROR( + "invalid chunk info, offset={}, chun size = {}, total_length={}", + offset, value.size(), total_length); + } + + bool should_reassemble = false; + std::shared_ptr data; + { + std::unique_lock lock(chunked_values_mutex_); + auto itr = chunked_values_.find(key); + if (itr == chunked_values_.end()) { + itr = chunked_values_ + .emplace(key, std::make_shared(total_length)) + .first; + } + + data = itr->second; + data->AddChunk(offset, value); + + if (data->IsFullyFilled()) { + chunked_values_.erase(itr); + + // only one thread do the reassemble + should_reassemble = true; + } + } + + if (should_reassemble) { + OnMessage(key, data->Reassemble()); + } +} + void ChannelBase::OnMessage(const std::string& key, ByteContainerView value) { std::unique_lock lock(msg_mutex_); if (key == kAckKey) { diff --git a/yacl/link/transport/channel.h b/yacl/link/transport/channel.h index b7446993..6f7ee13c 100644 --- a/yacl/link/transport/channel.h +++ b/yacl/link/transport/channel.h @@ -158,6 +158,9 @@ class ChannelBase : public IChannel, // wait for dummy msg from peer, timeout by recv_timeout_ms_. void TestRecv() final; + void OnChunkedMessage(const std::string& key, ByteContainerView value, + size_t offset, size_t total_length); + protected: virtual void SendImpl(const std::string& key, ByteContainerView value) = 0; @@ -243,6 +246,10 @@ class ChannelBase : public IChannel, std::atomic send_thread_stoped_ = false; SendTaskSynchronizer send_sync_; + // chunking related. + bthread::Mutex chunked_values_mutex_; + std::map> chunked_values_; + // message database related. bthread::Mutex msg_mutex_; bthread::ConditionVariable msg_db_cond_; diff --git a/yacl/link/transport/channel_brpc.cc b/yacl/link/transport/channel_brpc.cc index dd573af1..a285cd31 100644 --- a/yacl/link/transport/channel_brpc.cc +++ b/yacl/link/transport/channel_brpc.cc @@ -33,7 +33,7 @@ namespace internal { class ReceiverServiceImpl : public ic_pb::ReceiverService { public: explicit ReceiverServiceImpl( - std::map> listener) + std::map> listener) : listeners_(std::move(listener)) {} void Push(::google::protobuf::RpcController* /*cntl_base*/, @@ -54,7 +54,7 @@ class ReceiverServiceImpl : public ic_pb::ReceiverService { } protected: - std::map> listeners_; + std::map> listeners_; }; } // namespace internal diff --git a/yacl/link/transport/channel_brpc.h b/yacl/link/transport/channel_brpc.h index 16e48153..899caec7 100644 --- a/yacl/link/transport/channel_brpc.h +++ b/yacl/link/transport/channel_brpc.h @@ -26,7 +26,7 @@ #include "yacl/link/ssl_options.h" #include "yacl/link/transport/channel.h" -#include "yacl/link/transport/channel_chunked_base.h" +#include "yacl/link/transport/interconnection_base.h" namespace yacl::link::transport { @@ -48,7 +48,7 @@ class ReceiverLoopBrpc final : public IReceiverLoop { std::string Start(const std::string& host, const SSLOptions* ssl_opts = nullptr); - void AddListener(size_t rank, std::shared_ptr listener) { + void AddListener(size_t rank, std::shared_ptr listener) { auto ret = listeners_.emplace(rank, std::move(listener)); if (!ret.second) { YACL_THROW_LOGIC_ERROR("duplicated listener for rank={}", rank); @@ -56,20 +56,20 @@ class ReceiverLoopBrpc final : public IReceiverLoop { } protected: - std::map> listeners_; + std::map> listeners_; brpc::Server server_; private: void StopImpl(); }; -class ChannelBrpc final : public ChannelChunkedBase { +class ChannelBrpc final : public InterconnectionBase { public: - using ChannelChunkedBase::ChannelChunkedBase; + using InterconnectionBase::InterconnectionBase; - static ChannelChunkedBase::Options GetDefaultOptions() { - return ChannelChunkedBase::Options{10 * 1000, 512 * 1024, "baidu_std", - "single"}; + static InterconnectionBase::Options GetDefaultOptions() { + return InterconnectionBase::Options{10 * 1000, 512 * 1024, "baidu_std", + "single"}; } // from IChannel diff --git a/yacl/link/transport/channel_brpc_blackbox.h b/yacl/link/transport/channel_brpc_blackbox.h index f18e85a7..c50e2137 100644 --- a/yacl/link/transport/channel_brpc_blackbox.h +++ b/yacl/link/transport/channel_brpc_blackbox.h @@ -26,7 +26,7 @@ #include "yacl/link/ssl_options.h" #include "yacl/link/transport/channel.h" -#include "yacl/link/transport/channel_chunked_base.h" +#include "yacl/link/transport/interconnection_base.h" namespace yacl::link::transport::util { @@ -66,14 +66,14 @@ class ReceiverLoopBlackBox final : public IReceiverLoop { std::vector threads_; }; -class ChannelBrpcBlackBox final : public ChannelChunkedBase { +class ChannelBrpcBlackBox final : public InterconnectionBase { public: - static ChannelChunkedBase::Options GetDefaultOptions() { - return ChannelChunkedBase::Options{10 * 1000, 512 * 1024, "http", ""}; + static InterconnectionBase::Options GetDefaultOptions() { + return InterconnectionBase::Options{10 * 1000, 512 * 1024, "http", ""}; } public: - using ChannelChunkedBase::ChannelChunkedBase; + using InterconnectionBase::InterconnectionBase; ~ChannelBrpcBlackBox() override { if (is_recv_.load()) { diff --git a/yacl/link/transport/channel_brpc_test.cc b/yacl/link/transport/channel_brpc_test.cc index 42914013..39128ec7 100644 --- a/yacl/link/transport/channel_brpc_test.cc +++ b/yacl/link/transport/channel_brpc_test.cc @@ -484,8 +484,7 @@ class DummyReceiverLoopBrpc final : public IReceiverLoop { public: ~DummyReceiverLoopBrpc() override { Stop(); } - virtual void AddListener(size_t rank, - std::shared_ptr listener) { + virtual void AddListener(size_t rank, std::shared_ptr listener) { auto ret = listeners_.emplace(rank, std::move(listener)); if (!ret.second) { YACL_THROW_LOGIC_ERROR("duplicated listener for rank={}", rank); @@ -518,7 +517,7 @@ class DummyReceiverLoopBrpc final : public IReceiverLoop { } protected: - std::map> listeners_; + std::map> listeners_; brpc::Server server_; }; diff --git a/yacl/link/transport/channel_chunked_base.cc b/yacl/link/transport/interconnection_base.cc similarity index 74% rename from yacl/link/transport/channel_chunked_base.cc rename to yacl/link/transport/interconnection_base.cc index cc9db038..92b9a9b9 100644 --- a/yacl/link/transport/channel_chunked_base.cc +++ b/yacl/link/transport/interconnection_base.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "yacl/link/transport/channel_chunked_base.h" +#include "yacl/link/transport/interconnection_base.h" #include #include @@ -105,44 +105,13 @@ class SendChunkedWindow std::optional async_exception_; }; -class ChunkedMessage { - public: - explicit ChunkedMessage(int64_t message_length) : message_(message_length) {} - - void AddChunk(int64_t offset, ByteContainerView data) { - std::unique_lock lock(mutex_); - if (received_.emplace(offset).second) { - std::memcpy(message_.data() + offset, data.data(), - data.size()); - bytes_written_ += data.size(); - } - } - - bool IsFullyFilled() { - std::unique_lock lock(mutex_); - return bytes_written_ == message_.size(); - } - - Buffer&& Reassemble() { - std::unique_lock lock(mutex_); - return std::move(message_); - } - - protected: - bthread::Mutex mutex_; - std::set received_; - // chunk index to value. - int64_t bytes_written_{0}; - Buffer message_; -}; - -void ChannelChunkedBase::SendImpl(const std::string& key, - ByteContainerView value) { +void InterconnectionBase::SendImpl(const std::string& key, + ByteContainerView value) { SendImpl(key, value, 0); } -void ChannelChunkedBase::SendImpl(const std::string& key, - ByteContainerView value, uint32_t timeout) { +void InterconnectionBase::SendImpl(const std::string& key, + ByteContainerView value, uint32_t timeout) { if (value.size() > options_.http_max_payload_size) { SendChunked(key, value); return; @@ -161,7 +130,7 @@ void ChannelChunkedBase::SendImpl(const std::string& key, class SendChunkedTask { public: - SendChunkedTask(ChannelChunkedBase* channel, + SendChunkedTask(InterconnectionBase* channel, std::unique_ptr token, std::unique_ptr request) : channel_(channel), @@ -186,13 +155,13 @@ class SendChunkedTask { }; private: - ChannelChunkedBase* channel_; + InterconnectionBase* channel_; std::unique_ptr token_; std::unique_ptr request_; }; -void ChannelChunkedBase::SendChunked(const std::string& key, - ByteContainerView value) { +void InterconnectionBase::SendChunked(const std::string& key, + ByteContainerView value) { const size_t bytes_per_chunk = options_.http_max_payload_size; const size_t num_bytes = value.size(); const size_t num_chunks = (num_bytes + bytes_per_chunk - 1) / bytes_per_chunk; @@ -229,8 +198,8 @@ void ChannelChunkedBase::SendChunked(const std::string& key, window->Finished(); } -void ChannelChunkedBase::OnRequest(const ic_pb::PushRequest* request, - ic_pb::PushResponse* response) { +void InterconnectionBase::OnRequest(const ic_pb::PushRequest* request, + ic_pb::PushResponse* response) { auto trans_type = request->trans_type(); response->mutable_header()->set_error_code(ic::ErrorCode::OK); @@ -250,43 +219,7 @@ void ChannelChunkedBase::OnRequest(const ic_pb::PushRequest* request, } } -void ChannelChunkedBase::OnChunkedMessage(const std::string& key, - ByteContainerView value, - size_t offset, size_t total_length) { - if (offset + value.size() > total_length) { - YACL_THROW_LOGIC_ERROR( - "invalid chunk info, offset={}, chun size = {}, total_length={}", - offset, value.size(), total_length); - } - - bool should_reassemble = false; - std::shared_ptr data; - { - std::unique_lock lock(chunked_values_mutex_); - auto itr = chunked_values_.find(key); - if (itr == chunked_values_.end()) { - itr = chunked_values_ - .emplace(key, std::make_shared(total_length)) - .first; - } - - data = itr->second; - data->AddChunk(offset, value); - - if (data->IsFullyFilled()) { - chunked_values_.erase(itr); - - // only one thread do the reassemble - should_reassemble = true; - } - } - - if (should_reassemble) { - OnMessage(key, data->Reassemble()); - } -} - -auto ChannelChunkedBase::MakeOptions( +auto InterconnectionBase::MakeOptions( Options& default_opt, uint32_t http_timeout_ms, uint32_t http_max_payload_size, const std::string& brpc_channel_protocol, const std::string& brpc_channel_connection_type) -> Options { diff --git a/yacl/link/transport/channel_chunked_base.h b/yacl/link/transport/interconnection_base.h similarity index 83% rename from yacl/link/transport/channel_chunked_base.h rename to yacl/link/transport/interconnection_base.h index 64601478..176fd24c 100644 --- a/yacl/link/transport/channel_chunked_base.h +++ b/yacl/link/transport/interconnection_base.h @@ -34,7 +34,7 @@ class PushResponse; namespace yacl::link::transport { -class ChannelChunkedBase : public ChannelBase { +class InterconnectionBase : public ChannelBase { public: struct Options { uint32_t http_timeout_ms; // 10 seconds @@ -65,17 +65,15 @@ class ChannelChunkedBase : public ChannelBase { // send chunked, synchronized. void SendChunked(const std::string& key, ByteContainerView value); - void OnChunkedMessage(const std::string& key, ByteContainerView value, - size_t offset, size_t total_length); - public: - ChannelChunkedBase(size_t self_rank, size_t peer_rank, Options options, - bool exit_if_async_error = true) + InterconnectionBase(size_t self_rank, size_t peer_rank, Options options, + bool exit_if_async_error = true) : ChannelBase(self_rank, peer_rank, exit_if_async_error), options_(std::move(options)) {} - ChannelChunkedBase(size_t self_rank, size_t peer_rank, size_t recv_timeout_ms, - Options options, bool exit_if_async_error = true) + InterconnectionBase(size_t self_rank, size_t peer_rank, + size_t recv_timeout_ms, Options options, + bool exit_if_async_error = true) : ChannelBase(self_rank, peer_rank, recv_timeout_ms, exit_if_async_error), options_(std::move(options)) {} @@ -95,10 +93,6 @@ class ChannelChunkedBase : public ChannelBase { ::org::interconnection::link::PushResponse* response); protected: - // chunking related. - bthread::Mutex chunked_values_mutex_; - std::map> chunked_values_; - Options options_; };