Skip to content

Commit

Permalink
Add clippy changes
Browse files Browse the repository at this point in the history
  • Loading branch information
cgouert committed Apr 30, 2024
1 parent c979375 commit ac379f4
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 30 deletions.
15 changes: 9 additions & 6 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,16 @@ pub fn bior(table_size: u8, bit_width: u8) -> (Vec<u64>, Vec<u64>) {
let bior_lut: HashMap<u64, u64> = serde_json::from_reader(reader).unwrap();

// Convert to 1-D vector
let mut bior_lut_vec = bior_lut.into_iter().map(|(_, v)| v).collect::<Vec<_>>();
let mut bior_lut_vec = bior_lut.into_values().collect::<Vec<_>>();

// Break into two LUTs
bior_lut_vec.rotate_right(1 << (table_size - 1));
let mask = (1 << (bit_width / 2)) - 1;
let lsb = bior_lut_vec.iter().map(|x| x & mask).collect();
let msb = bior_lut_vec.iter().map(|x| x >> (bit_width / 2) & mask).collect();
let msb = bior_lut_vec
.iter()
.map(|x| x >> (bit_width / 2) & mask)
.collect();
(lsb, msb)
}

Expand All @@ -180,13 +183,13 @@ pub fn db2() -> (Vec<Vec<u64>>, Vec<u64>) {

// Convert LSB LUTs to 2-D vector
let lut_lsb_vecs = vec![
db2_lut_1.into_iter().map(|(_, v)| v).collect::<Vec<_>>(),
db2_lut_2.into_iter().map(|(_, v)| v).collect::<Vec<_>>(),
db2_lut_3.into_iter().map(|(_, v)| v).collect::<Vec<_>>(),
db2_lut_1.into_values().collect::<Vec<_>>(),
db2_lut_2.into_values().collect::<Vec<_>>(),
db2_lut_3.into_values().collect::<Vec<_>>(),
];

// Convert MSB LUT to 1-D vector
let lut_msb_vec = db2_lut_4.into_iter().map(|(_, v)| v).collect::<Vec<_>>();
let lut_msb_vec = db2_lut_4.into_values().collect::<Vec<_>>();

(lut_lsb_vecs, lut_msb_vec)
}
Expand Down
53 changes: 33 additions & 20 deletions src/lr_bior.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn main() {
let nb_blocks = bit_width >> 1;

// Number of blocks for n-J LSBs
let nb_blocks_lsb = (bit_width-j) >> 1;
let nb_blocks_lsb = (bit_width - j) >> 1;
println!("Number of blocks for LSB path: {:?}", nb_blocks_lsb);

// Number of blocks for J MSBs
Expand All @@ -70,8 +70,7 @@ fn main() {

let start = Instant::now();
// Generate radix keys
let (client_key, server_key) =
gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into());
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into());

// Generate key for PBS (without padding)
let wopbs_key = WopbsKey::new_wopbs_key(
Expand Down Expand Up @@ -112,14 +111,23 @@ fn main() {
let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64);
dummy = server_key.add_parallelized(&dummy_2, &dummy);
}
let dummy_blocks_msb = &dummy.into_blocks()[((nb_blocks as usize)-(nb_blocks_msb as usize))..(nb_blocks as usize)];
let dummy_blocks_msb = &dummy.into_blocks()
[((nb_blocks as usize) - (nb_blocks_msb as usize))..(nb_blocks as usize)];
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks_msb.to_vec());
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let mut msb_luts = Vec::new();
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_minus_1(x, &lut_lsb, 2u64.pow((lut_bit_width) as u32))));
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_minus_1(x, &lut_msb, 2u64.pow((lut_bit_width) as u32))));
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_minus_2(x, &lut_lsb, 2u64.pow((lut_bit_width) as u32))));
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_minus_2(x, &lut_msb, 2u64.pow((lut_bit_width) as u32))));
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| {
eval_lut_minus_1(x, &lut_lsb, 2u64.pow((lut_bit_width) as u32))
}));
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| {
eval_lut_minus_1(x, &lut_msb, 2u64.pow((lut_bit_width) as u32))
}));
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| {
eval_lut_minus_2(x, &lut_lsb, 2u64.pow((lut_bit_width) as u32))
}));
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| {
eval_lut_minus_2(x, &lut_msb, 2u64.pow((lut_bit_width) as u32))
}));
println!(
"LUT generation done in {:?} sec.",
lut_gen_start.elapsed().as_secs_f64()
Expand Down Expand Up @@ -151,8 +159,9 @@ fn main() {
server_key.scalar_add_parallelized(&prediction_lsb, 1)
},
|| {
let prediction_blocks_msb =
&prediction_blocks[((nb_blocks as usize)-(nb_blocks_msb as usize))..(nb_blocks as usize)];
let prediction_blocks_msb = &prediction_blocks[((nb_blocks as usize)
- (nb_blocks_msb as usize))
..(nb_blocks as usize)];
let prediction_msb =
RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec());
let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1);
Expand All @@ -170,9 +179,13 @@ fn main() {
.keyswitch_to_pbs_params(&prediction_lut_msb)
.into_blocks();
prediction_lut_lsb_blocks.extend(prediction_lut_msb_blocks);
let prediction_pt1_msb = RadixCiphertext::from_blocks(prediction_lut_lsb_blocks);
let prediction_pt1_msb =
RadixCiphertext::from_blocks(prediction_lut_lsb_blocks);
// Multiply by -1 modulo 2^(n-J)
let prediction_pt1_lsb = server_key.scalar_mul_parallelized(&prediction_lsb, 2u64.pow((bit_width - j) as u32)-1);
let prediction_pt1_lsb = server_key.scalar_mul_parallelized(
&prediction_lsb,
2u64.pow((bit_width - j) as u32) - 1,
);
// Multiply MSBs and LSBs
server_key.mul_parallelized(&prediction_pt1_msb, &prediction_pt1_lsb)
},
Expand All @@ -186,7 +199,8 @@ fn main() {
.keyswitch_to_pbs_params(&prediction_lut_msb)
.into_blocks();
prediction_lut_lsb_blocks.extend(prediction_lut_msb_blocks);
let prediction_pt2_msb = RadixCiphertext::from_blocks(prediction_lut_lsb_blocks);
let prediction_pt2_msb =
RadixCiphertext::from_blocks(prediction_lut_lsb_blocks);
// Multiply MSBs and LSBs
server_key.mul_parallelized(&prediction_pt2_msb, &prediction_lsb)
},
Expand Down Expand Up @@ -216,15 +230,13 @@ fn main() {
let (prediction_lsb, prediction_msb) = rayon::join(
|| {
let prediction_blocks_lsb = &prediction_blocks[0..(nb_blocks_lsb as usize)];
let prediction_lsb =
RadixCiphertext::from_blocks(prediction_blocks_lsb.to_vec());
let prediction_lsb = RadixCiphertext::from_blocks(prediction_blocks_lsb.to_vec());
server_key.scalar_add_parallelized(&prediction_lsb, 1)
},
|| {
let prediction_blocks_msb =
&prediction_blocks[((nb_blocks as usize)-(nb_blocks_msb as usize))..(nb_blocks as usize)];
let prediction_msb =
RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec());
let prediction_blocks_msb = &prediction_blocks
[((nb_blocks as usize) - (nb_blocks_msb as usize))..(nb_blocks as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec());
let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1);
wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb)
},
Expand All @@ -242,7 +254,8 @@ fn main() {
prediction_lut_lsb_blocks.extend(prediction_lut_msb_blocks);
let prediction_pt1_msb = RadixCiphertext::from_blocks(prediction_lut_lsb_blocks);
// Multiply by -1 modulo 2^(n-J)
let prediction_pt1_lsb = server_key.scalar_mul_parallelized(&prediction_lsb, 2u64.pow((bit_width - j) as u32)-1);
let prediction_pt1_lsb = server_key
.scalar_mul_parallelized(&prediction_lsb, 2u64.pow((bit_width - j) as u32) - 1);
// Multiply MSBs and LSBs
server_key.mul_parallelized(&prediction_pt1_msb, &prediction_pt1_lsb)
},
Expand Down
7 changes: 3 additions & 4 deletions src/lr_db2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ fn main() {

let start = Instant::now();
// Generate radix keys
let (client_key, server_key) =
gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into());
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into());

// Generate key for PBS (without padding)
let wopbs_key = WopbsKey::new_wopbs_key(
Expand Down Expand Up @@ -183,7 +182,7 @@ fn main() {
.map(|i| {
let activation_lsb = wopbs_key.wopbs(&prediction_lsb, &lsb_luts[i]);
let activation_msb = wopbs_key.wopbs(&prediction_msb, &msb_luts[i]);
let mut activation_lsb_blocks = wopbs_key
let activation_lsb_blocks = wopbs_key
.keyswitch_to_pbs_params(&activation_lsb)
.into_blocks();
let activation_lsb = RadixCiphertext::from_blocks(activation_lsb_blocks);
Expand Down Expand Up @@ -241,7 +240,7 @@ fn main() {
.map(|i| {
let activation_lsb = wopbs_key.wopbs(&prediction_lsb, &lsb_luts[i]);
let activation_msb = wopbs_key.wopbs(&prediction_msb, &msb_luts[i]);
let mut activation_lsb_blocks = wopbs_key
let activation_lsb_blocks = wopbs_key
.keyswitch_to_pbs_params(&activation_lsb)
.into_blocks();
let activation_lsb = RadixCiphertext::from_blocks(activation_lsb_blocks);
Expand Down

0 comments on commit ac379f4

Please sign in to comment.