Skip to content

Commit

Permalink
Repo Sync (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Aug 18, 2023
1 parent ff20dff commit a9c1d7d
Show file tree
Hide file tree
Showing 32 changed files with 554 additions and 300 deletions.
8 changes: 8 additions & 0 deletions yacl/crypto/primitives/ot/ferret_ote.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ OtSendStore FerretOtExtSend(const std::shared_ptr<link::Context>& ctx,
YACL_ENFORCE(ctx->WorldSize() == 2); // Make sure that OT has two parties
YACL_ENFORCE(base_cot.Type() == OtStoreType::Compact);
YACL_ENFORCE(base_cot.Size() >= FerretCotHelper(lpn_param, ot_num));
YACL_ENFORCE(ot_num >= FerretCotHelper(lpn_param, ot_num),
"ot_num is {}, which should be much greater than the minium "
"size of base cot {}",
ot_num, FerretCotHelper(lpn_param, ot_num));

// get constants: the number of cot needed for mpcot phase
const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n);
Expand Down Expand Up @@ -152,6 +156,10 @@ OtRecvStore FerretOtExtRecv(const std::shared_ptr<link::Context>& ctx,
YACL_ENFORCE(ctx->WorldSize() == 2); // Make sure that OT has two parties
YACL_ENFORCE(base_cot.Type() == OtStoreType::Compact);
YACL_ENFORCE(base_cot.Size() >= FerretCotHelper(lpn_param, ot_num));
YACL_ENFORCE(ot_num >= FerretCotHelper(lpn_param, ot_num),
"ot_num is {}, which should be much greater than the minium "
"size of base cot {}",
ot_num, FerretCotHelper(lpn_param, ot_num));

// get constants: the number of cot needed for mpcot phase
const auto mpcot_cot_num = MpCotRNHelper(lpn_param.t, lpn_param.n);
Expand Down
4 changes: 2 additions & 2 deletions yacl/crypto/primitives/ot/ferret_ote_rn.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void MpCotRNSend(const std::shared_ptr<link::Context>& ctx,
// for each bin, call single-point cot
for (uint64_t i = 0; i < batch_num; ++i) {
const uint64_t this_size =
(i == batch_size - 1) ? full_size - i * batch_size : batch_size;
(i == batch_num - 1) ? full_size - i * batch_size : batch_size;
const auto& cot_slice = cot.Slice(i * math::Log2Ceil(this_size),
(i + 1) * math::Log2Ceil(this_size));

Expand All @@ -58,7 +58,7 @@ void MpCotRNRecv(const std::shared_ptr<link::Context>& ctx,
// for each bin, call single-point cot
for (uint64_t i = 0; i < batch_num; ++i) {
const uint64_t this_size =
(i == batch_size - 1) ? full_size - i * batch_size : batch_size;
(i == batch_num - 1) ? full_size - i * batch_size : batch_size;
const auto cot_slice = cot.Slice(i * math::Log2Ceil(this_size),
(i + 1) * math::Log2Ceil(this_size));
FerretGywzOtExtRecv(ctx, cot_slice, this_size,
Expand Down
27 changes: 27 additions & 0 deletions yacl/crypto/primitives/ot/ferret_ote_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,31 @@ INSTANTIATE_TEST_SUITE_P(
FerretParams{1 << 24, LpnNoiseAsm::RegularNoise},
FerretParams{1 << 25, LpnNoiseAsm::RegularNoise}));

TEST(FerretOtExtEdgeTest, Test) {
// GIVEN
const int kWorldSize = 2;
const auto assumption = LpnNoiseAsm::RegularNoise;

auto lctxs = link::test::SetupWorld(kWorldSize); // setup network
auto lpn_param = LpnParam(10485760, 452000, 1280, assumption);

// ot_num < minium size of base_cot
const size_t ot_num = FerretCotHelper(lpn_param, 0) - 1;
auto cot_num = FerretCotHelper(lpn_param, ot_num); // make option
auto cots_compact = MockCompactOts(cot_num); // mock cots

// WHEN
auto sender = std::async([&] {
ASSERT_THROW(
FerretOtExtSend(lctxs[0], cots_compact.send, lpn_param, ot_num),
::yacl::Exception);
});
auto receiver = std::async([&] {
ASSERT_THROW(
FerretOtExtRecv(lctxs[1], cots_compact.recv, lpn_param, ot_num),
::yacl::Exception);
});
sender.get();
receiver.get();
}
} // namespace yacl::crypto
1 change: 0 additions & 1 deletion yacl/crypto/primitives/tpre/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ yacl_cc_library(
hdrs = ["hash.h"],
deps = [
":kdf",
"//yacl/base:dynamic_bitset",
"//yacl/crypto/base/ecc:spi",
"//yacl/crypto/base/hash:hash_utils",
"//yacl/math/mpint",
Expand Down
68 changes: 13 additions & 55 deletions yacl/crypto/primitives/tpre/capsule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ namespace yacl::crypto {
std::pair<Capsule::CapsuleStruct, std::vector<uint8_t>> Capsule::EnCapsulate(
const std::unique_ptr<EcGroup>& ecc_group,
const Keys::PublicKey& delegating_public_key) const {
MPInt zero_bn(0);
MPInt order = ecc_group->GetOrder();
MPInt r;
MPInt::RandomLtN(order, &r);
Expand All @@ -35,25 +34,15 @@ std::pair<Capsule::CapsuleStruct, std::vector<uint8_t>> Capsule::EnCapsulate(

EcPoint E = ecc_group->MulBase(r);
EcPoint V = ecc_group->MulBase(u);
std::string E_string_join_V_sting =
std::string(ecc_group->SerializePoint(E)) +
std::string(ecc_group->SerializePoint(V));

MPInt s = u.AddMod(
r.MulMod(CipherHash(E_string_join_V_sting, ecc_group), order), order);
MPInt s = u.AddMod(r.MulMod(CipherHash({E, V}, ecc_group), order), order);

EcPoint K_point = ecc_group->Mul(delegating_public_key.y, u.AddMod(r, order));

std::string K_string = std::string(ecc_group->SerializePoint(K_point));

Capsule::CapsuleStruct capsule_struct = {E, V, s};
std::vector<uint8_t> K = KDF(K_string, 16);

std::pair<Capsule::CapsuleStruct, std::vector<uint8_t>> capsule_pair;
capsule_pair.first = capsule_struct;
capsule_pair.second = K;
std::vector<uint8_t> K = KDF(ecc_group->SerializePoint(K_point), 16);

return capsule_pair;
return {capsule_struct, K};
}

// Decapsulate(skA,capsule)->(K)
Expand All @@ -63,9 +52,8 @@ std::vector<uint8_t> Capsule::DeCapsulate(
const CapsuleStruct& capsule_struct) const {
EcPoint K_point = ecc_group->Mul(
ecc_group->Add(capsule_struct.E, capsule_struct.V), private_key.x);
std::string K_string = std::string(ecc_group->SerializePoint(K_point));

std::vector<uint8_t> K = KDF(K_string, 16);
std::vector<uint8_t> K = KDF(ecc_group->SerializePoint(K_point), 16);

return K;
}
Expand All @@ -76,39 +64,22 @@ std::pair<Capsule::CapsuleStruct, int> Capsule::CheckCapsule(
EcPoint tmp0 = ecc_group->MulBase(capsule_struct.s);

// compute H_2(E,V)
std::string E_string_join_V_sting =
std::string(ecc_group->SerializePoint(capsule_struct.E)) +
std::string(ecc_group->SerializePoint(capsule_struct.V));
MPInt hev = CipherHash(E_string_join_V_sting, ecc_group);
MPInt hev = CipherHash({capsule_struct.E, capsule_struct.V}, ecc_group);

EcPoint e_exp_hev = ecc_group->Mul(capsule_struct.E, hev);
EcPoint tmp1 = ecc_group->Add(capsule_struct.V, e_exp_hev);

std::string tmp0_string = std::string(ecc_group->SerializePoint(tmp0));
std::string tmp1_string = std::string(ecc_group->SerializePoint(tmp1));

int signal;
if (tmp0_string == tmp1_string) {
signal = 1;
} else {
signal = 0;
}

std::pair<Capsule::CapsuleStruct, int> capsule_check_result = {capsule_struct,
signal};

return capsule_check_result;
return {capsule_struct, ecc_group->PointEqual(tmp0, tmp1)};
}

// /**
// * Each Re-encryptor generates the ciphertext fragment, i.e., cfrag
// * */
//
// Each Re-encryptor generates the ciphertext fragment, i.e., cfrag
//
Capsule::CFrag Capsule::ReEncapsulate(const std::unique_ptr<EcGroup>& ecc_group,
const Keys::KFrag& kfrag,
const CapsuleStruct& capsule) const {
// First checks the validity of the capsule with CheckCapsule and outputs ⊥
// if the check fails.

auto capsule_check_result = CheckCapsule(ecc_group, capsule);

YACL_ENFORCE(capsule_check_result.second == 1,
Expand All @@ -130,18 +101,11 @@ std::vector<uint8_t> Capsule::DeCapsulateFrags(
const std::unique_ptr<EcGroup>& ecc_group, const Keys::PrivateKey& sk_B,
const Keys::PublicKey& pk_A, const Keys::PublicKey& pk_B,
const std::vector<CFrag>& cfrags) const {
MPInt one_bn(1);

// Compute (pk_B)^a
EcPoint pk_A_mul_b = ecc_group->Mul(pk_A.y, sk_B.x);
std::string pk_A_mul_b_str =
std::string(ecc_group->SerializePoint(pk_A_mul_b));
std::string pk_A_str = std::string(ecc_group->SerializePoint(pk_A.y));
std::string pk_B_str = std::string(ecc_group->SerializePoint(pk_B.y));

// 1. Compute D = H_6(pk_A, pk_B, (pk_A)^b)

MPInt D = CipherHash(pk_A_str + pk_B_str + pk_A_mul_b_str, ecc_group);
MPInt D = CipherHash({pk_A.y, pk_B.y, pk_A_mul_b}, ecc_group);

// 2. Compute s_{x,i} and lambda_{i,S}
// 2.1 Compute s_{x,i} = H_5(id_i, D)
Expand Down Expand Up @@ -191,20 +155,14 @@ std::vector<uint8_t> Capsule::DeCapsulateFrags(
// 4. Compute d = H_3(X_A,pk_B,(X_A)^b)
std::string X_A_str = std::string(ecc_group->SerializePoint(cfrags[0].X_A));
EcPoint X_A_mul_b = ecc_group->Mul(cfrags[0].X_A, sk_B.x);
std::string X_A_mul_b_str = std::string(ecc_group->SerializePoint(X_A_mul_b));

MPInt d = CipherHash(X_A_str + pk_B_str + X_A_mul_b_str, ecc_group);
MPInt d = CipherHash({cfrags[0].X_A, pk_B.y, X_A_mul_b}, ecc_group);

// 5. Compute DEK, i.e., K=KDF((E'· V')^d)

EcPoint E_prime_add_V_prime = ecc_group->Add(E_prime, V_prime);
EcPoint E_prime_add_V_prime_mul_d = ecc_group->Mul(E_prime_add_V_prime, d);

std::string K_string =
std::string(ecc_group->SerializePoint(E_prime_add_V_prime_mul_d));

std::vector<uint8_t> K = KDF(K_string, 16);

return K;
return KDF(ecc_group->SerializePoint(E_prime_add_V_prime_mul_d), 16);
}

} // namespace yacl::crypto
66 changes: 33 additions & 33 deletions yacl/crypto/primitives/tpre/capsule.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef YACL_CRYPTO_PRIMITIVES_TPRE_CAPSULE_H_
#define YACL_CRYPTO_PRIMITIVES_TPRE_CAPSULE_H_
#pragma once

#include <iostream>
#include <memory>
#include <string>
Expand All @@ -36,71 +36,71 @@ namespace yacl::crypto {
*/
class Capsule {
public:
Capsule() {}
~Capsule() {}
Capsule() = default;
~Capsule() = default;

/// @brief Capsule for encapsulating data keys, Capsule struct includes 2
/// elliptic curve point and a big number
// @brief Capsule for encapsulating data keys, Capsule struct includes 2
// elliptic curve point and a big number
struct CapsuleStruct {
EcPoint E; // E = g^r, g is the generator of elliptic group
EcPoint V; // V = g^u, g is the generator of elliptic group
MPInt s; // s = u + r · H(E, V)
};

/// @brief CFrag is the fragment of Capsule after re-encapsulating
// @brief CFrag is the fragment of Capsule after re-encapsulating
struct CFrag {
EcPoint E_1; // E_1 = E^rk
EcPoint V_1; // V_1 = V^rk
MPInt id; // identity number of each proxy
EcPoint X_A; // X_A = g^x_A
};

/// @brief EnCapsulate algorithm, generate and capsulate the random data
/// encryption key
/// @param ecc_group
/// @param delegating_public_key
/// @return capsule and data ecnryption key
// @brief EnCapsulate algorithm, generate and capsulate the random data
// encryption key
// @param ecc_group
// @param delegating_public_key
// @return capsule and data ecnryption key

std::pair<CapsuleStruct, std::vector<uint8_t>> EnCapsulate(
const std::unique_ptr<EcGroup>& ecc_group,
const Keys::PublicKey& delegating_public_key) const;

/// @brief DeCapsulate algorithm, to obtain the data encryption key
/// @param private_key
/// @param capsule_struct
/// @return data encryption key
// @brief DeCapsulate algorithm, to obtain the data encryption key
// @param private_key
// @param capsule_struct
// @return data encryption key
std::vector<uint8_t> DeCapsulate(const std::unique_ptr<EcGroup>& ecc_group,
const Keys::PrivateKey& private_key,
const CapsuleStruct& capsule_struct) const;

/// @brief Capsule check algorithm
/// @param ecc_group
/// @param capsule_struct
/// @return 0 (check fail) or 1 (check success)
// @brief Capsule check algorithm
// @param ecc_group
// @param capsule_struct
// @return 0 (check fail) or 1 (check success)
std::pair<CapsuleStruct, int> CheckCapsule(
const std::unique_ptr<EcGroup>& ecc_group,
const CapsuleStruct& capsule_struct) const;

/// @brief Re-encapsulate capsule
/// @param ecc_group
/// @param kfrag, re-encryption key fragment
/// @param capsule
/// @return Re-encapsulated capsule
// @brief Re-encapsulate capsule
// @param ecc_group
// @param kfrag, re-encryption key fragment
// @param capsule
// @return Re-encapsulated capsule
CFrag ReEncapsulate(const std::unique_ptr<EcGroup>& ecc_group,
const Keys::KFrag& kfrag,
const CapsuleStruct& capsule) const;

/// @brief Restore the re-encapsulated capsule set to data encryption key
/// @param ecc_group
/// @param sk_B, secret key of Bob
/// @param pk_A, public key of Alice
/// @param pk_B, public key of Bob
/// @param cfrags, re-encapsulated capsule set
/// @return Data encryption key
// @brief Restore the re-encapsulated capsule set to data encryption key
// @param ecc_group
// @param sk_B, secret key of Bob
// @param pk_A, public key of Alice
// @param pk_B, public key of Bob
// @param cfrags, re-encapsulated capsule set
// @return Data encryption key
std::vector<uint8_t> DeCapsulateFrags(
const std::unique_ptr<EcGroup>& ecc_group, const Keys::PrivateKey& sk_B,
const Keys::PublicKey& pk_A, const Keys::PublicKey& pk_B,
const std::vector<CFrag>& cfrags) const;
};

} // namespace yacl::crypto
#endif // YACL_CRYPTO_PRIMITIVES_TPRE_CAPSULE_H_
33 changes: 24 additions & 9 deletions yacl/crypto/primitives/tpre/hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <string>
#include <vector>

#include "yacl/base/dynamic_bitset.h"
#include "yacl/crypto/base/hash/hash_utils.h"
#include "yacl/crypto/primitives/tpre/kdf.h"

Expand All @@ -30,18 +29,34 @@ namespace yacl::crypto {
// where n is the degree of EC Group, and x is input
MPInt CipherHash(ByteContainerView input,
const std::unique_ptr<EcGroup>& ecc_group) {
std::array<unsigned char, 32> hash_value_0 = Sm3(input);
std::array<unsigned char, 32> hash_value_1 = Sm3(hash_value_0);
auto hash_value_0 = Sm3(input);
auto hash_value_1 = Sm3(hash_value_0);

dynamic_bitset<uint8_t> binary;
binary.append(hash_value_0.begin(), hash_value_0.end());
binary.append(hash_value_1.begin(), hash_value_1.end());
MPInt hash_bn(binary.to_string(), 2);
std::vector<uint8_t> buf;
buf.insert(buf.end(), hash_value_0.begin(), hash_value_0.end());
buf.insert(buf.end(), hash_value_1.begin(), hash_value_1.end());

MPInt hash_bn;
hash_bn.FromMagBytes(buf);

MPInt one_bn(1);
// h_x = 1 + Bignum(sm3(x)||sm3(sm3(x))) mod n-1
MPInt h_x = one_bn.AddMod(hash_bn, ecc_group->GetOrder() - one_bn);
MPInt h_x = hash_bn.AddMod(1_mp, ecc_group->GetOrder() - 1_mp);

return h_x;
}

MPInt CipherHash(std::initializer_list<EcPoint> inputs,
const std::unique_ptr<EcGroup>& ecc_group) {
auto len = ecc_group->GetSerializeLength();
Buffer buf(len * inputs.size());

uint8_t index = 0;
for (const auto& p : inputs) {
ecc_group->SerializePoint(p, buf.data<uint8_t>() + index * len, len);
index++;
}

return CipherHash(buf, ecc_group);
}

} // namespace yacl::crypto
Loading

0 comments on commit a9c1d7d

Please sign in to comment.