From ac379f4fed52a5bef508641fbcbb95c468ed4db4 Mon Sep 17 00:00:00 2001 From: cgouert Date: Tue, 30 Apr 2024 15:24:50 -0400 Subject: [PATCH] Add clippy changes --- src/common.rs | 15 ++++++++------ src/lr_bior.rs | 53 +++++++++++++++++++++++++++++++------------------- src/lr_db2.rs | 7 +++---- 3 files changed, 45 insertions(+), 30 deletions(-) diff --git a/src/common.rs b/src/common.rs index d21c7d3..1426a11 100644 --- a/src/common.rs +++ b/src/common.rs @@ -157,13 +157,16 @@ pub fn bior(table_size: u8, bit_width: u8) -> (Vec, Vec) { let bior_lut: HashMap = serde_json::from_reader(reader).unwrap(); // Convert to 1-D vector - let mut bior_lut_vec = bior_lut.into_iter().map(|(_, v)| v).collect::>(); + let mut bior_lut_vec = bior_lut.into_values().collect::>(); // Break into two LUTs bior_lut_vec.rotate_right(1 << (table_size - 1)); let mask = (1 << (bit_width / 2)) - 1; let lsb = bior_lut_vec.iter().map(|x| x & mask).collect(); - let msb = bior_lut_vec.iter().map(|x| x >> (bit_width / 2) & mask).collect(); + let msb = bior_lut_vec + .iter() + .map(|x| x >> (bit_width / 2) & mask) + .collect(); (lsb, msb) } @@ -180,13 +183,13 @@ pub fn db2() -> (Vec>, Vec) { // Convert LSB LUTs to 2-D vector let lut_lsb_vecs = vec![ - db2_lut_1.into_iter().map(|(_, v)| v).collect::>(), - db2_lut_2.into_iter().map(|(_, v)| v).collect::>(), - db2_lut_3.into_iter().map(|(_, v)| v).collect::>(), + db2_lut_1.into_values().collect::>(), + db2_lut_2.into_values().collect::>(), + db2_lut_3.into_values().collect::>(), ]; // Convert MSB LUT to 1-D vector - let lut_msb_vec = db2_lut_4.into_iter().map(|(_, v)| v).collect::>(); + let lut_msb_vec = db2_lut_4.into_values().collect::>(); (lut_lsb_vecs, lut_msb_vec) } diff --git a/src/lr_bior.rs b/src/lr_bior.rs index 701d490..9f3ab02 100644 --- a/src/lr_bior.rs +++ b/src/lr_bior.rs @@ -61,7 +61,7 @@ fn main() { let nb_blocks = bit_width >> 1; // Number of blocks for n-J LSBs - let nb_blocks_lsb = (bit_width-j) >> 1; + let nb_blocks_lsb = (bit_width - j) >> 1; println!("Number of blocks for LSB path: {:?}", nb_blocks_lsb); // Number of blocks for J MSBs @@ -70,8 +70,7 @@ fn main() { let start = Instant::now(); // Generate radix keys - let (client_key, server_key) = - gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into()); + let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into()); // Generate key for PBS (without padding) let wopbs_key = WopbsKey::new_wopbs_key( @@ -112,14 +111,23 @@ fn main() { let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64); dummy = server_key.add_parallelized(&dummy_2, &dummy); } - let dummy_blocks_msb = &dummy.into_blocks()[((nb_blocks as usize)-(nb_blocks_msb as usize))..(nb_blocks as usize)]; + let dummy_blocks_msb = &dummy.into_blocks() + [((nb_blocks as usize) - (nb_blocks_msb as usize))..(nb_blocks as usize)]; let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks_msb.to_vec()); let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1); let mut msb_luts = Vec::new(); - msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_minus_1(x, &lut_lsb, 2u64.pow((lut_bit_width) as u32)))); - msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_minus_1(x, &lut_msb, 2u64.pow((lut_bit_width) as u32)))); - msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_minus_2(x, &lut_lsb, 2u64.pow((lut_bit_width) as u32)))); - msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_minus_2(x, &lut_msb, 2u64.pow((lut_bit_width) as u32)))); + msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| { + eval_lut_minus_1(x, &lut_lsb, 2u64.pow((lut_bit_width) as u32)) + })); + msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| { + eval_lut_minus_1(x, &lut_msb, 2u64.pow((lut_bit_width) as u32)) + })); + msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| { + eval_lut_minus_2(x, &lut_lsb, 2u64.pow((lut_bit_width) as u32)) + })); + msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| { + eval_lut_minus_2(x, &lut_msb, 2u64.pow((lut_bit_width) as u32)) + })); println!( "LUT generation done in {:?} sec.", lut_gen_start.elapsed().as_secs_f64() @@ -151,8 +159,9 @@ fn main() { server_key.scalar_add_parallelized(&prediction_lsb, 1) }, || { - let prediction_blocks_msb = - &prediction_blocks[((nb_blocks as usize)-(nb_blocks_msb as usize))..(nb_blocks as usize)]; + let prediction_blocks_msb = &prediction_blocks[((nb_blocks as usize) + - (nb_blocks_msb as usize)) + ..(nb_blocks as usize)]; let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec()); let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1); @@ -170,9 +179,13 @@ fn main() { .keyswitch_to_pbs_params(&prediction_lut_msb) .into_blocks(); prediction_lut_lsb_blocks.extend(prediction_lut_msb_blocks); - let prediction_pt1_msb = RadixCiphertext::from_blocks(prediction_lut_lsb_blocks); + let prediction_pt1_msb = + RadixCiphertext::from_blocks(prediction_lut_lsb_blocks); // Multiply by -1 modulo 2^(n-J) - let prediction_pt1_lsb = server_key.scalar_mul_parallelized(&prediction_lsb, 2u64.pow((bit_width - j) as u32)-1); + let prediction_pt1_lsb = server_key.scalar_mul_parallelized( + &prediction_lsb, + 2u64.pow((bit_width - j) as u32) - 1, + ); // Multiply MSBs and LSBs server_key.mul_parallelized(&prediction_pt1_msb, &prediction_pt1_lsb) }, @@ -186,7 +199,8 @@ fn main() { .keyswitch_to_pbs_params(&prediction_lut_msb) .into_blocks(); prediction_lut_lsb_blocks.extend(prediction_lut_msb_blocks); - let prediction_pt2_msb = RadixCiphertext::from_blocks(prediction_lut_lsb_blocks); + let prediction_pt2_msb = + RadixCiphertext::from_blocks(prediction_lut_lsb_blocks); // Multiply MSBs and LSBs server_key.mul_parallelized(&prediction_pt2_msb, &prediction_lsb) }, @@ -216,15 +230,13 @@ fn main() { let (prediction_lsb, prediction_msb) = rayon::join( || { let prediction_blocks_lsb = &prediction_blocks[0..(nb_blocks_lsb as usize)]; - let prediction_lsb = - RadixCiphertext::from_blocks(prediction_blocks_lsb.to_vec()); + let prediction_lsb = RadixCiphertext::from_blocks(prediction_blocks_lsb.to_vec()); server_key.scalar_add_parallelized(&prediction_lsb, 1) }, || { - let prediction_blocks_msb = - &prediction_blocks[((nb_blocks as usize)-(nb_blocks_msb as usize))..(nb_blocks as usize)]; - let prediction_msb = - RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec()); + let prediction_blocks_msb = &prediction_blocks + [((nb_blocks as usize) - (nb_blocks_msb as usize))..(nb_blocks as usize)]; + let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec()); let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1); wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb) }, @@ -242,7 +254,8 @@ fn main() { prediction_lut_lsb_blocks.extend(prediction_lut_msb_blocks); let prediction_pt1_msb = RadixCiphertext::from_blocks(prediction_lut_lsb_blocks); // Multiply by -1 modulo 2^(n-J) - let prediction_pt1_lsb = server_key.scalar_mul_parallelized(&prediction_lsb, 2u64.pow((bit_width - j) as u32)-1); + let prediction_pt1_lsb = server_key + .scalar_mul_parallelized(&prediction_lsb, 2u64.pow((bit_width - j) as u32) - 1); // Multiply MSBs and LSBs server_key.mul_parallelized(&prediction_pt1_msb, &prediction_pt1_lsb) }, diff --git a/src/lr_db2.rs b/src/lr_db2.rs index 16fb2ca..f0fd559 100644 --- a/src/lr_db2.rs +++ b/src/lr_db2.rs @@ -74,8 +74,7 @@ fn main() { let start = Instant::now(); // Generate radix keys - let (client_key, server_key) = - gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into()); + let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into()); // Generate key for PBS (without padding) let wopbs_key = WopbsKey::new_wopbs_key( @@ -183,7 +182,7 @@ fn main() { .map(|i| { let activation_lsb = wopbs_key.wopbs(&prediction_lsb, &lsb_luts[i]); let activation_msb = wopbs_key.wopbs(&prediction_msb, &msb_luts[i]); - let mut activation_lsb_blocks = wopbs_key + let activation_lsb_blocks = wopbs_key .keyswitch_to_pbs_params(&activation_lsb) .into_blocks(); let activation_lsb = RadixCiphertext::from_blocks(activation_lsb_blocks); @@ -241,7 +240,7 @@ fn main() { .map(|i| { let activation_lsb = wopbs_key.wopbs(&prediction_lsb, &lsb_luts[i]); let activation_msb = wopbs_key.wopbs(&prediction_msb, &msb_luts[i]); - let mut activation_lsb_blocks = wopbs_key + let activation_lsb_blocks = wopbs_key .keyswitch_to_pbs_params(&activation_lsb) .into_blocks(); let activation_lsb = RadixCiphertext::from_blocks(activation_lsb_blocks);