Skip to content

Commit

Permalink
Fix linter issues
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Oct 26, 2023
1 parent 1367738 commit 76f375c
Show file tree
Hide file tree
Showing 34 changed files with 1,344 additions and 1,190 deletions.
7 changes: 5 additions & 2 deletions common/datagen/datagen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,17 @@ fn main() {

let fn_a = format!("{}/input_{}_size_{}_cols_{}.csv", dir, "a", size, cols);
let fn_b = format!("{}/input_{}_size_{}_cols_{}.csv", dir, "b", size, cols);
let fn_b_features = format!("{}/input_{}_size_{}_cols_{}_features.csv", dir, "b", size, cols);
let fn_b_features = format!(
"{}/input_{}_size_{}_cols_{}_features.csv",
dir, "b", size, cols
);

info!("Generating output of size {}", size);
info!("Player a output: {}", fn_a);
info!("Player b output: {}", fn_b);
info!("Player b features: {}", fn_b_features);

let intrsct = size / 2 as usize;
let intrsct = size / 2_usize;
let size_player = size - intrsct;
let data = gen::random_data(size_player, size_player, intrsct);
info!("Data generation done, writing to files");
Expand Down
4 changes: 3 additions & 1 deletion common/src/files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ where
it.map(|x| {
x.unwrap()
.iter()
.map(|z| u64::from_str(z.trim()).unwrap_or_else(|_| panic!("Cannot format {} as u64", z)))
.map(|z| {
u64::from_str(z.trim()).unwrap_or_else(|_| panic!("Cannot format {} as u64", z))
})
.collect::<Vec<u64>>()
})
.collect::<Vec<Vec<u64>>>()
Expand Down
14 changes: 5 additions & 9 deletions common/src/s3_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ use std::path::Path;
use std::str::FromStr;
use std::time::Duration;

use aws_sdk_s3::Region;
use aws_config::default_provider::credentials::default_provider;
use aws_credential_types::cache::CredentialsCache;
use aws_sdk_s3::error::NoSuchUpload;
use aws_sdk_s3::model::CompletedPart;
use aws_sdk_s3::model::CompletedMultipartUpload;
use aws_sdk_s3::model::CompletedPart;
use aws_sdk_s3::types::ByteStream;
use aws_config::default_provider::credentials::default_provider;
use aws_credential_types::cache::CredentialsCache;
use aws_sdk_s3::Region;
use regex::Regex;

lazy_static::lazy_static! {
Expand Down Expand Up @@ -135,11 +135,7 @@ impl S3Path {
.await
.unwrap();
let uid = u.upload_id().ok_or_else(|| {
aws_sdk_s3::Error::NoSuchUpload(
NoSuchUpload::builder()
.message("No upload ID")
.build(),
)
aws_sdk_s3::Error::NoSuchUpload(NoSuchUpload::builder().message("No upload ID").build())
})?;
let mut completed_parts: Vec<CompletedPart> = Vec::new();
for i in 0..chunks {
Expand Down
40 changes: 12 additions & 28 deletions protocol-rpc/src/connect/create_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ use tonic::transport::Endpoint;
use crate::connect::tls;
use crate::proto::gen_crosspsi::cross_psi_client::CrossPsiClient;
use crate::proto::gen_crosspsi_xor::cross_psi_xor_client::CrossPsiXorClient;
use crate::proto::gen_pjc::pjc_client::PjcClient;
use crate::proto::gen_private_id::private_id_client::PrivateIdClient;
use crate::proto::gen_private_id_multi_key::private_id_multi_key_client::PrivateIdMultiKeyClient;
use crate::proto::gen_suid_create::suid_create_client::SuidCreateClient;
use crate::proto::gen_dpmc_company::dpmc_company_client::DpmcCompanyClient;
use crate::proto::gen_dpmc_partner::dpmc_partner_client::DpmcPartnerClient;
use crate::proto::gen_dspmc_company::dspmc_company_client::DspmcCompanyClient;
use crate::proto::gen_dspmc_helper::dspmc_helper_client::DspmcHelperClient;
use crate::proto::gen_dspmc_partner::dspmc_partner_client::DspmcPartnerClient;
use crate::proto::gen_pjc::pjc_client::PjcClient;
use crate::proto::gen_private_id::private_id_client::PrivateIdClient;
use crate::proto::gen_private_id_multi_key::private_id_multi_key_client::PrivateIdMultiKeyClient;
use crate::proto::gen_suid_create::suid_create_client::SuidCreateClient;
use crate::proto::RpcClient;

pub fn create_client(
Expand Down Expand Up @@ -145,21 +145,11 @@ pub fn create_client(
"cross-psi-xor" => RpcClient::CrossPsiXor(CrossPsiXorClient::new(conn)),
"pjc" => RpcClient::Pjc(PjcClient::new(conn)),
"suid-create" => RpcClient::SuidCreate(SuidCreateClient::new(conn)),
"dpmc-company" => RpcClient::DpmcCompany(
DpmcCompanyClient::new(conn),
),
"dpmc-partner" => RpcClient::DpmcPartner(
DpmcPartnerClient::new(conn),
),
"dspmc-company" => RpcClient::DspmcCompany(
DspmcCompanyClient::new(conn),
),
"dspmc-helper" => RpcClient::DspmcHelper(
DspmcHelperClient::new(conn),
),
"dspmc-partner" => RpcClient::DspmcPartner(
DspmcPartnerClient::new(conn),
),
"dpmc-company" => RpcClient::DpmcCompany(DpmcCompanyClient::new(conn)),
"dpmc-partner" => RpcClient::DpmcPartner(DpmcPartnerClient::new(conn)),
"dspmc-company" => RpcClient::DspmcCompany(DspmcCompanyClient::new(conn)),
"dspmc-helper" => RpcClient::DspmcHelper(DspmcHelperClient::new(conn)),
"dspmc-partner" => RpcClient::DspmcPartner(DspmcPartnerClient::new(conn)),
_ => panic!("wrong client"),
})
} else {
Expand Down Expand Up @@ -187,19 +177,13 @@ pub fn create_client(
DpmcPartnerClient::connect(__uri).await.unwrap(),
)),
"dspmc-company" => Ok(RpcClient::DspmcCompany(
DspmcCompanyClient::connect(__uri)
.await
.unwrap(),
DspmcCompanyClient::connect(__uri).await.unwrap(),
)),
"dspmc-helper" => Ok(RpcClient::DspmcHelper(
DspmcHelperClient::connect(__uri)
.await
.unwrap(),
DspmcHelperClient::connect(__uri).await.unwrap(),
)),
"dspmc-partner" => Ok(RpcClient::DspmcPartner(
DspmcPartnerClient::connect(__uri)
.await
.unwrap(),
DspmcPartnerClient::connect(__uri).await.unwrap(),
)),
_ => panic!("wrong client"),
}
Expand Down
8 changes: 4 additions & 4 deletions protocol-rpc/src/proto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ pub mod streaming;

use gen_crosspsi::cross_psi_client::CrossPsiClient;
use gen_crosspsi_xor::cross_psi_xor_client::CrossPsiXorClient;
use gen_pjc::pjc_client::PjcClient;
use gen_private_id::private_id_client::PrivateIdClient;
use gen_private_id_multi_key::private_id_multi_key_client::PrivateIdMultiKeyClient;
use gen_suid_create::suid_create_client::SuidCreateClient;
use gen_dpmc_company::dpmc_company_client::DpmcCompanyClient;
use gen_dpmc_partner::dpmc_partner_client::DpmcPartnerClient;
use gen_dspmc_company::dspmc_company_client::DspmcCompanyClient;
use gen_dspmc_helper::dspmc_helper_client::DspmcHelperClient;
use gen_dspmc_partner::dspmc_partner_client::DspmcPartnerClient;
use gen_pjc::pjc_client::PjcClient;
use gen_private_id::private_id_client::PrivateIdClient;
use gen_private_id_multi_key::private_id_multi_key_client::PrivateIdMultiKeyClient;
use gen_suid_create::suid_create_client::SuidCreateClient;
use tonic::transport::Channel;
pub enum RpcClient {
PrivateId(PrivateIdClient<Channel>),
Expand Down
162 changes: 98 additions & 64 deletions protocol-rpc/src/rpc/dpmc/client.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,26 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// SPDX-License-Identifier: Apache-2.0

use clap::{App, Arg, ArgGroup};
use log::{error, info};
use std::convert::TryInto;
use tonic::Request;

use clap::App;
use clap::Arg;
use clap::ArgGroup;
use common::timer;
use crypto::prelude::TPayload;
use protocol::dpmc::{helper::HelperDpmc, traits::*};
use rpc::{
connect::create_client::create_client,
proto::{
gen_dpmc_company::{
service_response::Ack as CompanyAck,
Init as CompanyInit,
ServiceResponse as CompanyServiceResponse
},
gen_dpmc_partner::{
service_response::Ack as PartnerAck,
Init as PartnerInit,
SendData as PartnerSendData,
},
RpcClient,
},
};
use log::error;
use log::info;
use protocol::dpmc::helper::HelperDpmc;
use protocol::dpmc::traits::*;
use rpc::connect::create_client::create_client;
use rpc::proto::gen_dpmc_company::service_response::Ack as CompanyAck;
use rpc::proto::gen_dpmc_company::Init as CompanyInit;
use rpc::proto::gen_dpmc_company::ServiceResponse as CompanyServiceResponse;
use rpc::proto::gen_dpmc_partner::service_response::Ack as PartnerAck;
use rpc::proto::gen_dpmc_partner::Init as PartnerInit;
use rpc::proto::gen_dpmc_partner::SendData as PartnerSendData;
use rpc::proto::RpcClient;
use tonic::Request;

mod rpc_client_company;
mod rpc_client_partner;
Expand Down Expand Up @@ -62,13 +59,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Arg::with_name("output-shares-path")
.long("output-shares-path")
.takes_value(true)
.help("path to write shares of features.\n
Feature will be written as {path}_partner_features.csv"),
.help(
"path to write shares of features.\n
Feature will be written as {path}_partner_features.csv",
),
Arg::with_name("one-to-many")
.long("one-to-many")
.takes_value(true)
.required(false)
.help("By default, DPMC generates one-to-one matches. Use this\n
.help(
"By default, DPMC generates one-to-one matches. Use this\n
flag to generate one(C)-to-many(P) matches.",
),
Arg::with_name("no-tls")
Expand Down Expand Up @@ -226,34 +226,34 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

for i in 0..partner_client_context.len() {
// Send company public key
let _ =
match rpc_client_partner::send(
company_public_key.clone(),
"company_public_key".to_string(),
&mut partner_client_context[i])
.await?
.into_inner()
.ack
.unwrap()
{
PartnerAck::CompanyPublicKeyAck(x) => x,
_ => panic!("wrong ack"),
};
let _ = match rpc_client_partner::send(
company_public_key.clone(),
"company_public_key".to_string(),
&mut partner_client_context[i],
)
.await?
.into_inner()
.ack
.unwrap()
{
PartnerAck::CompanyPublicKeyAck(x) => x,
_ => panic!("wrong ack"),
};

// Send helper public key
let _ =
match rpc_client_partner::send(
helper_public_key.clone(),
"helper_public_key".to_string(),
&mut partner_client_context[i])
.await?
.into_inner()
.ack
.unwrap()
{
PartnerAck::HelperPublicKeyAck(x) => x,
_ => panic!("wrong ack"),
};
let _ = match rpc_client_partner::send(
helper_public_key.clone(),
"helper_public_key".to_string(),
&mut partner_client_context[i],
)
.await?
.into_inner()
.ack
.unwrap()
{
PartnerAck::HelperPublicKeyAck(x) => x,
_ => panic!("wrong ack"),
};
}
}

Expand All @@ -273,11 +273,23 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await?;

let offset_len = u64::from_le_bytes(
h_company_beta.pop().unwrap().buffer.as_slice().try_into().unwrap(),
h_company_beta
.pop()
.unwrap()
.buffer
.as_slice()
.try_into()
.unwrap(),
) as usize;
// flattened len
let data_len = u64::from_le_bytes(
h_company_beta.pop().unwrap().buffer.as_slice().try_into().unwrap(),
h_company_beta
.pop()
.unwrap()
.buffer
.as_slice()
.try_into()
.unwrap(),
) as usize;

let offset = h_company_beta
Expand Down Expand Up @@ -331,7 +343,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await?;

let xor_shares_len = u64::from_le_bytes(
h_partner_alpha_beta.pop().unwrap().buffer.as_slice().try_into().unwrap()
h_partner_alpha_beta
.pop()
.unwrap()
.buffer
.as_slice()
.try_into()
.unwrap(),
) as usize;

let xor_shares = h_partner_alpha_beta
Expand All @@ -346,11 +364,23 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// deserialize ragged array
let num_partner_keys = u64::from_le_bytes(
h_partner_alpha_beta.pop().unwrap().buffer.as_slice().try_into().unwrap(),
h_partner_alpha_beta
.pop()
.unwrap()
.buffer
.as_slice()
.try_into()
.unwrap(),
) as usize;
// flattened len
let data_len = u64::from_le_bytes(
h_partner_alpha_beta.pop().unwrap().buffer.as_slice().try_into().unwrap(),
h_partner_alpha_beta
.pop()
.unwrap()
.buffer
.as_slice()
.try_into()
.unwrap(),
) as usize;

let offset = h_partner_alpha_beta
Expand All @@ -364,7 +394,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Perform 1/alpha, where alpha = partner.alpha.
// Then decrypt XOR secret shares and compute features and mask.
helper_protocol.remove_partner_scalar_from_p_and_set_shares(
h_partner_alpha_beta, offset, enc_alpha_t.buffer, vec![p_scalar_times_g], xor_shares
h_partner_alpha_beta,
offset,
enc_alpha_t.buffer,
vec![p_scalar_times_g],
xor_shares,
)?;
}

Expand All @@ -385,14 +419,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let v_d_prime = helper_protocol.calculate_features_xor_shares()?;

// 13. Set XOR share of features for company
let _ = rpc_client_company::calculate_features_xor_shares(
v_d_prime,
&mut company_client_context,
)
.await?
.into_inner()
.ack
.unwrap();
let _ =
rpc_client_company::calculate_features_xor_shares(v_d_prime, &mut company_client_context)
.await?
.into_inner()
.ack
.unwrap();

// 14. Print Company's ID spine and save partners shares
rpc_client_company::reveal(&mut company_client_context).await?;
Expand All @@ -405,7 +437,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// 16. Print Helper's feature shares
match output_shares_path {
Some(p) => helper_protocol.save_features_shares(&String::from(p)).unwrap(),
Some(p) => helper_protocol
.save_features_shares(&String::from(p))
.unwrap(),
None => error!("Output features path not set. Can't output shares"),
};

Expand Down
Loading

0 comments on commit 76f375c

Please sign in to comment.