diff --git a/Cargo.toml b/Cargo.toml index 9b4c9f5..37a6bcb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ license = "MIT" [dependencies] clap = "3.0" +num-integer = "0.1.46" csv = "1.3" debug_print = "1.0.0" dwt = "0.5.2" @@ -47,3 +48,7 @@ path = "src/sigmoid.rs" [[bin]] name = "sigmoid_encrypted_dwt" path = "src/sigmoid_encrypted_dwt.rs" + +[[bin]] +name = "encrypted_lr_dwt_db2" +path = "src/encrypted_lr_dwt_db2.rs" \ No newline at end of file diff --git a/src/common.rs b/src/common.rs index c2b3272..257b5cc 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,4 +1,4 @@ -use std::fs::File; +use std::{collections::HashMap, fs::File, io::BufReader}; use dwt::{transform, wavelet::Haar, Operation}; @@ -146,6 +146,36 @@ pub fn haar(table_size: u8, precision: u8, bit_width: u8) -> (Vec, Vec (lsb, msb) } +pub fn db2() -> (Vec>, Vec) { + // Read DB2 LUTs + let reader = BufReader::new(File::open("../data/lut_lsb_h1.json").unwrap()); + let lut_lsb_h1: HashMap = serde_json::from_reader(reader).unwrap(); + let reader = BufReader::new(File::open("../data/lut_lsb_h2.json").unwrap()); + let lut_lsb_h2: HashMap = serde_json::from_reader(reader).unwrap(); + let reader = BufReader::new(File::open("../data/lut_lsb_h3.json").unwrap()); + let lut_lsb_h3: HashMap = serde_json::from_reader(reader).unwrap(); + let reader = BufReader::new(File::open("../data/lut_msb_h4.json").unwrap()); + let lut_msb_h4: HashMap = serde_json::from_reader(reader).unwrap(); + + // Convert LSB LUTs to 2-D vector + let lut_lsb_len = lut_lsb_h1.keys().max().unwrap_or(&0); + let mut lut_lsb_vecs: Vec> = vec![vec![0; (*lut_lsb_len + 1) as usize]; 3]; + for i in 0..=*lut_lsb_len { + lut_lsb_vecs[0][i as usize] = lut_lsb_h1.get(&i).cloned().unwrap_or(0); + lut_lsb_vecs[1][i as usize] = lut_lsb_h2.get(&i).cloned().unwrap_or(0); + lut_lsb_vecs[2][i as usize] = lut_lsb_h3.get(&i).cloned().unwrap_or(0); + } + + // Convert MSB LUT to 1-D vector + let lut_msb_len = lut_msb_h4.keys().max().unwrap_or(&0); + let mut lut_msb_vec: Vec = vec![0; (*lut_msb_len + 1) as usize]; + for (key, value) in lut_msb_h4 { + lut_msb_vec[key as usize] = value; + } + + (lut_lsb_vecs, lut_msb_vec) +} + pub fn quantized_table(table_size: u8, precision: u8, bit_width: u8) -> (Vec, Vec) { let mut data = Vec::new(); let max = 1 << (table_size); diff --git a/src/encrypted_lr_dwt_db2.rs b/src/encrypted_lr_dwt_db2.rs new file mode 100644 index 0000000..6fe06c9 --- /dev/null +++ b/src/encrypted_lr_dwt_db2.rs @@ -0,0 +1,291 @@ +use std::time::Instant; + +use clap::{App, Arg}; +use fhe_lut::common::*; +use rayon::prelude::*; +// use serde::{Deserialize, Serialize}; +use tfhe::{ + integer::{ + // ciphertext::BaseRadixCiphertext, + gen_keys_radix, + wopbs::*, + IntegerCiphertext, + IntegerRadixCiphertext, + RadixCiphertext, + }, + shortint::parameters::{ + parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + }, +}; + +// fn eval_lut(x: u64, lut_entries: &Vec) -> u64 { +// lut_entries[x as usize] +// } + +// fn eval_lut_sll_1(x: u64, lut_entries: &Vec) -> u64 { +// lut_entries[(x << 1) as usize] +// } + +// fn eval_lut_sll_2(x: u64, lut_entries: &Vec) -> u64 { +// lut_entries[(x << 2) as usize] +// } + +fn eval_lut_dummy(x: u64) -> u64 { + x * 2 +} + +fn eval_lut_sll_1_dummy(x: u64) -> u64 { + (x << 1) * 2 +} + +fn eval_lut_sll_2_dummy(x: u64) -> u64 { + (x << 2) * 2 +} + +fn main() { + let matches = App::new("Ripple") + .about("Encrypted Logistic Regression with DB2 DWT LUTs") + .arg( + Arg::new("num-samples") + .long("num-samples") + .short('n') + .takes_value(true) + .value_name("INT") + .help("Number of samples") + .default_value("1") + .required(false), + ) + .get_matches(); + + let num_samples = matches + .value_of("num-samples") + .unwrap_or("1") + .parse::() + .expect("Number of samples must be an integer"); + + // ------- Client side ------- // + let bit_width = 24u8; + let precision = 8; + let j = 8; // wave depth + assert!(precision <= bit_width / 2); + + // let (lut_lsbs, lut_msb) = db2(); + + // Number of blocks for full precision + let nb_blocks = bit_width >> 1; + + // Number of blocks for J LSBs + let nb_blocks_lsb = j >> 1; + println!("Number of blocks for LSB path: {:?}", nb_blocks_lsb); + + // Number of blocks for n-J MSBs + let nb_blocks_msb = (bit_width - j) >> 1; + println!("Number of blocks for MSB path: {:?}", nb_blocks_msb); + + let start = Instant::now(); + // Generate radix keys + let (client_key, server_key) = + gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks_msb.into()); + + // Generate key for PBS (without padding) + let wopbs_key = WopbsKey::new_wopbs_key( + &client_key, + &server_key, + &WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + ); + println!( + "Key generation done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + let (weights, bias) = load_weights_and_biases(); + let (weights_int, bias_int) = quantize_weights_and_bias(&weights, bias, precision, bit_width); + let (dataset, targets) = prepare_penguins_dataset(); + + let start = Instant::now(); + let mut encrypted_dataset: Vec> = dataset + .par_iter() // Use par_iter() for parallel iteration + .map(|sample| { + sample + .par_iter() + .map(|&s| { + let quantized = quantize(s, precision, bit_width); + let mut lsb = client_key + .encrypt(quantized & (1 << ((nb_blocks << 1) - 1))) + .into_blocks(); // Get LSBs + let msb = client_key + .encrypt(quantized >> (nb_blocks << 1)) + .into_blocks(); // Get MSBs + lsb.extend(msb); + RadixCiphertext::from_blocks(lsb) + }) + .collect() + }) + .collect(); + println!( + "Encryption done in {:?} sec.", + start.elapsed().as_secs_f64() + ); + + // ------- Server side ------- // + + let lut_gen_start = Instant::now(); + println!("Generating LUT."); + let mut dummy: RadixCiphertext = server_key.create_trivial_radix(2_u64, nb_blocks.into()); + for _ in 0..weights_int.len() { + let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64); + dummy = server_key.add_parallelized(&dummy_2, &dummy); + } + let dummy_blocks = &dummy.into_blocks(); + let dummy_blocks_lsb = &dummy_blocks[0..((j >> 1) as usize)]; + let dummy_blocks_msb = &dummy_blocks[((j >> 1) as usize)..(nb_blocks as usize)]; + let dummy_lsb = RadixCiphertext::from_blocks(dummy_blocks_lsb.to_vec()); + let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks_msb.to_vec()); + let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1); + let dummy_lsb = server_key.scalar_add_parallelized(&dummy_lsb, 1); + let mut lsb_luts = Vec::new(); + let mut msb_luts = Vec::new(); + for _ in 0..3 { + lsb_luts.push(wopbs_key.generate_lut_radix(&dummy_lsb, |x: u64| eval_lut_dummy(x))); + } + msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_dummy(x))); + msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_sll_1_dummy(x))); + msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_sll_2_dummy(x))); + println!( + "LUT generation done in {:?} sec.", + lut_gen_start.elapsed().as_secs_f64() + ); + + // Inference + assert!(num_samples <= encrypted_dataset.len()); + let all_probabilities = if num_samples > 1 { + encrypted_dataset + .par_iter_mut() + .enumerate() + .take(num_samples) + .map(|(cnt, sample)| { + let start = Instant::now(); + println!("Started inference #{:?}.", cnt); + + let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into()); + for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) { + let ct_prod = server_key.scalar_mul_parallelized(s, weight); + prediction = server_key.add_parallelized(&ct_prod, &prediction); + } + // Split into J LSBs and n-J MSBs + let prediction_blocks = &prediction.into_blocks(); + let prediction_blocks_lsb = &prediction_blocks[0..((j >> 1) as usize)]; + let prediction_blocks_msb = + &prediction_blocks[((j >> 1) as usize)..(nb_blocks as usize)]; + let prediction_lsb = RadixCiphertext::from_blocks(prediction_blocks_lsb.to_vec()); + let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec()); + let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1); + let prediction_lsb = server_key.scalar_add_parallelized(&prediction_lsb, 1); + // Evaluate LUTs and multiply + let prediction_msb = + wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb); + let prediction_lsb = + wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_lsb); + let mut prods = Vec::new(); + for i in 0..3 { + let activation_lsb = wopbs_key.wopbs(&prediction_lsb, &lsb_luts[i]); + let activation_msb = wopbs_key.wopbs(&prediction_msb, &msb_luts[i]); + let mut activation_lsb_blocks = wopbs_key + .keyswitch_to_pbs_params(&activation_lsb) + .into_blocks(); + // Pad LSBs to n-J bits + let padding: RadixCiphertext = + server_key.create_trivial_radix(0, ((bit_width - 2 * j) >> 1).into()); + let padding_blocks = padding.into_blocks(); + activation_lsb_blocks.extend(padding_blocks); + let activation_lsb = RadixCiphertext::from_blocks(activation_lsb_blocks); + let activation_msb = wopbs_key.keyswitch_to_pbs_params(&activation_msb); + // Multiply and pad to n bits + let mut ct_prod_blocks = server_key + .mul_parallelized(&activation_lsb, &activation_msb) + .into_blocks(); + let padding: RadixCiphertext = + server_key.create_trivial_radix(0, (j >> 1).into()); + let padding_blocks = padding.into_blocks(); + ct_prod_blocks.extend(padding_blocks); + prods.push(RadixCiphertext::from_blocks(ct_prod_blocks.to_vec())); + } + // Sum products + let probability = server_key.add_parallelized(&prods[0], &prods[1]); + let probability = server_key.add_parallelized(&probability, &prods[2]); + println!( + "Finished inference #{:?} in {:?} sec.", + cnt, + start.elapsed().as_secs_f64() + ); + probability + }) + .collect::>() + } else { + let start = Instant::now(); + println!("Started inference."); + + let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into()); + for (s, &weight) in encrypted_dataset[0].iter_mut().zip(weights_int.iter()) { + let ct_prod = server_key.scalar_mul_parallelized(s, weight); + prediction = server_key.add_parallelized(&ct_prod, &prediction); + } + // Split into J LSBs and n-J MSBs + let prediction_blocks = &prediction.into_blocks(); + let prediction_blocks_lsb = &prediction_blocks[0..((j >> 1) as usize)]; + let prediction_blocks_msb = &prediction_blocks[((j >> 1) as usize)..(nb_blocks as usize)]; + let prediction_lsb = RadixCiphertext::from_blocks(prediction_blocks_lsb.to_vec()); + let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec()); + let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1); + let prediction_lsb = server_key.scalar_add_parallelized(&prediction_lsb, 1); + // Evaluate LUTs and multiply + let prediction_msb = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb); + let prediction_lsb = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_lsb); + let mut prods = Vec::new(); + for i in 0..3 { + let activation_lsb = wopbs_key.wopbs(&prediction_lsb, &lsb_luts[i]); + let activation_msb = wopbs_key.wopbs(&prediction_msb, &msb_luts[i]); + let mut activation_lsb_blocks = wopbs_key + .keyswitch_to_pbs_params(&activation_lsb) + .into_blocks(); + // Pad LSBs to n-J bits + let padding: RadixCiphertext = + server_key.create_trivial_radix(0, ((bit_width - 2 * j) >> 1).into()); + let padding_blocks = padding.into_blocks(); + activation_lsb_blocks.extend(padding_blocks); + let activation_lsb = RadixCiphertext::from_blocks(activation_lsb_blocks); + let activation_msb = wopbs_key.keyswitch_to_pbs_params(&activation_msb); + // Multiply and pad to n bits + let mut ct_prod_blocks = server_key + .mul_parallelized(&activation_lsb, &activation_msb) + .into_blocks(); + let padding: RadixCiphertext = server_key.create_trivial_radix(0, (j >> 1).into()); + let padding_blocks = padding.into_blocks(); + ct_prod_blocks.extend(padding_blocks); + prods.push(RadixCiphertext::from_blocks(ct_prod_blocks.to_vec())); + } + // Sum products + let probability = server_key.add_parallelized(&prods[0], &prods[1]); + let probability = server_key.add_parallelized(&probability, &prods[2]); + println!( + "Finished inference in {:?} sec.", + start.elapsed().as_secs_f64() + ); + vec![probability] + }; + + // ------- Client side ------- // + let mut total = 0; + for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() { + let ptxt_probability: u64 = client_key.decrypt(probability); + + let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize; + println!("[{}] predicted {:?}, target {:?}", num, class, target); + if class == *target { + total += 1; + } + } + let accuracy = (total as f32 / num_samples as f32) * 100.0; + println!("Accuracy {accuracy}%"); +}