Skip to content

Commit

Permalink
Add PoC for encrypted LR with DB2 DWT LUTs
Browse files Browse the repository at this point in the history
  • Loading branch information
cgouert committed Apr 24, 2024
1 parent 42e40e9 commit 3ac55d7
Show file tree
Hide file tree
Showing 3 changed files with 327 additions and 1 deletion.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ license = "MIT"

[dependencies]
clap = "3.0"
num-integer = "0.1.46"
csv = "1.3"
debug_print = "1.0.0"
dwt = "0.5.2"
Expand Down Expand Up @@ -47,3 +48,7 @@ path = "src/sigmoid.rs"
[[bin]]
name = "sigmoid_encrypted_dwt"
path = "src/sigmoid_encrypted_dwt.rs"

[[bin]]
name = "encrypted_lr_dwt_db2"
path = "src/encrypted_lr_dwt_db2.rs"
32 changes: 31 additions & 1 deletion src/common.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fs::File;
use std::{collections::HashMap, fs::File, io::BufReader};

use dwt::{transform, wavelet::Haar, Operation};

Expand Down Expand Up @@ -146,6 +146,36 @@ pub fn haar(table_size: u8, precision: u8, bit_width: u8) -> (Vec<u64>, Vec<u64>
(lsb, msb)
}

pub fn db2() -> (Vec<Vec<u64>>, Vec<u64>) {
// Read DB2 LUTs
let reader = BufReader::new(File::open("../data/lut_lsb_h1.json").unwrap());
let lut_lsb_h1: HashMap<u64, u64> = serde_json::from_reader(reader).unwrap();
let reader = BufReader::new(File::open("../data/lut_lsb_h2.json").unwrap());
let lut_lsb_h2: HashMap<u64, u64> = serde_json::from_reader(reader).unwrap();
let reader = BufReader::new(File::open("../data/lut_lsb_h3.json").unwrap());
let lut_lsb_h3: HashMap<u64, u64> = serde_json::from_reader(reader).unwrap();
let reader = BufReader::new(File::open("../data/lut_msb_h4.json").unwrap());
let lut_msb_h4: HashMap<u64, u64> = serde_json::from_reader(reader).unwrap();

// Convert LSB LUTs to 2-D vector
let lut_lsb_len = lut_lsb_h1.keys().max().unwrap_or(&0);
let mut lut_lsb_vecs: Vec<Vec<u64>> = vec![vec![0; (*lut_lsb_len + 1) as usize]; 3];
for i in 0..=*lut_lsb_len {
lut_lsb_vecs[0][i as usize] = lut_lsb_h1.get(&i).cloned().unwrap_or(0);
lut_lsb_vecs[1][i as usize] = lut_lsb_h2.get(&i).cloned().unwrap_or(0);
lut_lsb_vecs[2][i as usize] = lut_lsb_h3.get(&i).cloned().unwrap_or(0);
}

// Convert MSB LUT to 1-D vector
let lut_msb_len = lut_msb_h4.keys().max().unwrap_or(&0);
let mut lut_msb_vec: Vec<u64> = vec![0; (*lut_msb_len + 1) as usize];
for (key, value) in lut_msb_h4 {
lut_msb_vec[key as usize] = value;
}

(lut_lsb_vecs, lut_msb_vec)
}

pub fn quantized_table(table_size: u8, precision: u8, bit_width: u8) -> (Vec<u64>, Vec<u64>) {
let mut data = Vec::new();
let max = 1 << (table_size);
Expand Down
291 changes: 291 additions & 0 deletions src/encrypted_lr_dwt_db2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
use std::time::Instant;

use clap::{App, Arg};
use fhe_lut::common::*;
use rayon::prelude::*;
// use serde::{Deserialize, Serialize};
use tfhe::{
integer::{
// ciphertext::BaseRadixCiphertext,
gen_keys_radix,
wopbs::*,
IntegerCiphertext,
IntegerRadixCiphertext,
RadixCiphertext,
},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
},
};

// fn eval_lut(x: u64, lut_entries: &Vec<u64>) -> u64 {
// lut_entries[x as usize]
// }

// fn eval_lut_sll_1(x: u64, lut_entries: &Vec<u64>) -> u64 {
// lut_entries[(x << 1) as usize]
// }

// fn eval_lut_sll_2(x: u64, lut_entries: &Vec<u64>) -> u64 {
// lut_entries[(x << 2) as usize]
// }

fn eval_lut_dummy(x: u64) -> u64 {
x * 2
}

fn eval_lut_sll_1_dummy(x: u64) -> u64 {
(x << 1) * 2
}

fn eval_lut_sll_2_dummy(x: u64) -> u64 {
(x << 2) * 2
}

fn main() {
let matches = App::new("Ripple")
.about("Encrypted Logistic Regression with DB2 DWT LUTs")
.arg(
Arg::new("num-samples")
.long("num-samples")
.short('n')
.takes_value(true)
.value_name("INT")
.help("Number of samples")
.default_value("1")
.required(false),
)
.get_matches();

let num_samples = matches
.value_of("num-samples")
.unwrap_or("1")
.parse::<usize>()
.expect("Number of samples must be an integer");

// ------- Client side ------- //
let bit_width = 24u8;
let precision = 8;
let j = 8; // wave depth
assert!(precision <= bit_width / 2);

// let (lut_lsbs, lut_msb) = db2();

// Number of blocks for full precision
let nb_blocks = bit_width >> 1;

// Number of blocks for J LSBs
let nb_blocks_lsb = j >> 1;
println!("Number of blocks for LSB path: {:?}", nb_blocks_lsb);

// Number of blocks for n-J MSBs
let nb_blocks_msb = (bit_width - j) >> 1;
println!("Number of blocks for MSB path: {:?}", nb_blocks_msb);

let start = Instant::now();
// Generate radix keys
let (client_key, server_key) =
gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks_msb.into());

// Generate key for PBS (without padding)
let wopbs_key = WopbsKey::new_wopbs_key(
&client_key,
&server_key,
&WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
);
println!(
"Key generation done in {:?} sec.",
start.elapsed().as_secs_f64()
);

let (weights, bias) = load_weights_and_biases();
let (weights_int, bias_int) = quantize_weights_and_bias(&weights, bias, precision, bit_width);
let (dataset, targets) = prepare_penguins_dataset();

let start = Instant::now();
let mut encrypted_dataset: Vec<Vec<_>> = dataset
.par_iter() // Use par_iter() for parallel iteration
.map(|sample| {
sample
.par_iter()
.map(|&s| {
let quantized = quantize(s, precision, bit_width);
let mut lsb = client_key
.encrypt(quantized & (1 << ((nb_blocks << 1) - 1)))
.into_blocks(); // Get LSBs
let msb = client_key
.encrypt(quantized >> (nb_blocks << 1))
.into_blocks(); // Get MSBs
lsb.extend(msb);
RadixCiphertext::from_blocks(lsb)
})
.collect()
})
.collect();
println!(
"Encryption done in {:?} sec.",
start.elapsed().as_secs_f64()
);

// ------- Server side ------- //

let lut_gen_start = Instant::now();
println!("Generating LUT.");
let mut dummy: RadixCiphertext = server_key.create_trivial_radix(2_u64, nb_blocks.into());
for _ in 0..weights_int.len() {
let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64);
dummy = server_key.add_parallelized(&dummy_2, &dummy);
}
let dummy_blocks = &dummy.into_blocks();
let dummy_blocks_lsb = &dummy_blocks[0..((j >> 1) as usize)];
let dummy_blocks_msb = &dummy_blocks[((j >> 1) as usize)..(nb_blocks as usize)];
let dummy_lsb = RadixCiphertext::from_blocks(dummy_blocks_lsb.to_vec());
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks_msb.to_vec());
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let dummy_lsb = server_key.scalar_add_parallelized(&dummy_lsb, 1);
let mut lsb_luts = Vec::new();
let mut msb_luts = Vec::new();
for _ in 0..3 {
lsb_luts.push(wopbs_key.generate_lut_radix(&dummy_lsb, |x: u64| eval_lut_dummy(x)));
}
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_dummy(x)));
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_sll_1_dummy(x)));
msb_luts.push(wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_lut_sll_2_dummy(x)));
println!(
"LUT generation done in {:?} sec.",
lut_gen_start.elapsed().as_secs_f64()
);

// Inference
assert!(num_samples <= encrypted_dataset.len());
let all_probabilities = if num_samples > 1 {
encrypted_dataset
.par_iter_mut()
.enumerate()
.take(num_samples)
.map(|(cnt, sample)| {
let start = Instant::now();
println!("Started inference #{:?}.", cnt);

let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into());
for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) {
let ct_prod = server_key.scalar_mul_parallelized(s, weight);
prediction = server_key.add_parallelized(&ct_prod, &prediction);
}
// Split into J LSBs and n-J MSBs
let prediction_blocks = &prediction.into_blocks();
let prediction_blocks_lsb = &prediction_blocks[0..((j >> 1) as usize)];
let prediction_blocks_msb =
&prediction_blocks[((j >> 1) as usize)..(nb_blocks as usize)];
let prediction_lsb = RadixCiphertext::from_blocks(prediction_blocks_lsb.to_vec());
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec());
let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1);
let prediction_lsb = server_key.scalar_add_parallelized(&prediction_lsb, 1);
// Evaluate LUTs and multiply
let prediction_msb =
wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let prediction_lsb =
wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_lsb);
let mut prods = Vec::new();
for i in 0..3 {
let activation_lsb = wopbs_key.wopbs(&prediction_lsb, &lsb_luts[i]);
let activation_msb = wopbs_key.wopbs(&prediction_msb, &msb_luts[i]);
let mut activation_lsb_blocks = wopbs_key
.keyswitch_to_pbs_params(&activation_lsb)
.into_blocks();
// Pad LSBs to n-J bits
let padding: RadixCiphertext =
server_key.create_trivial_radix(0, ((bit_width - 2 * j) >> 1).into());
let padding_blocks = padding.into_blocks();
activation_lsb_blocks.extend(padding_blocks);
let activation_lsb = RadixCiphertext::from_blocks(activation_lsb_blocks);
let activation_msb = wopbs_key.keyswitch_to_pbs_params(&activation_msb);
// Multiply and pad to n bits
let mut ct_prod_blocks = server_key
.mul_parallelized(&activation_lsb, &activation_msb)
.into_blocks();
let padding: RadixCiphertext =
server_key.create_trivial_radix(0, (j >> 1).into());
let padding_blocks = padding.into_blocks();
ct_prod_blocks.extend(padding_blocks);
prods.push(RadixCiphertext::from_blocks(ct_prod_blocks.to_vec()));
}
// Sum products
let probability = server_key.add_parallelized(&prods[0], &prods[1]);
let probability = server_key.add_parallelized(&probability, &prods[2]);
println!(
"Finished inference #{:?} in {:?} sec.",
cnt,
start.elapsed().as_secs_f64()
);
probability
})
.collect::<Vec<_>>()
} else {
let start = Instant::now();
println!("Started inference.");

let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into());
for (s, &weight) in encrypted_dataset[0].iter_mut().zip(weights_int.iter()) {
let ct_prod = server_key.scalar_mul_parallelized(s, weight);
prediction = server_key.add_parallelized(&ct_prod, &prediction);
}
// Split into J LSBs and n-J MSBs
let prediction_blocks = &prediction.into_blocks();
let prediction_blocks_lsb = &prediction_blocks[0..((j >> 1) as usize)];
let prediction_blocks_msb = &prediction_blocks[((j >> 1) as usize)..(nb_blocks as usize)];
let prediction_lsb = RadixCiphertext::from_blocks(prediction_blocks_lsb.to_vec());
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks_msb.to_vec());
let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1);
let prediction_lsb = server_key.scalar_add_parallelized(&prediction_lsb, 1);
// Evaluate LUTs and multiply
let prediction_msb = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let prediction_lsb = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_lsb);
let mut prods = Vec::new();
for i in 0..3 {
let activation_lsb = wopbs_key.wopbs(&prediction_lsb, &lsb_luts[i]);
let activation_msb = wopbs_key.wopbs(&prediction_msb, &msb_luts[i]);
let mut activation_lsb_blocks = wopbs_key
.keyswitch_to_pbs_params(&activation_lsb)
.into_blocks();
// Pad LSBs to n-J bits
let padding: RadixCiphertext =
server_key.create_trivial_radix(0, ((bit_width - 2 * j) >> 1).into());
let padding_blocks = padding.into_blocks();
activation_lsb_blocks.extend(padding_blocks);
let activation_lsb = RadixCiphertext::from_blocks(activation_lsb_blocks);
let activation_msb = wopbs_key.keyswitch_to_pbs_params(&activation_msb);
// Multiply and pad to n bits
let mut ct_prod_blocks = server_key
.mul_parallelized(&activation_lsb, &activation_msb)
.into_blocks();
let padding: RadixCiphertext = server_key.create_trivial_radix(0, (j >> 1).into());
let padding_blocks = padding.into_blocks();
ct_prod_blocks.extend(padding_blocks);
prods.push(RadixCiphertext::from_blocks(ct_prod_blocks.to_vec()));
}
// Sum products
let probability = server_key.add_parallelized(&prods[0], &prods[1]);
let probability = server_key.add_parallelized(&probability, &prods[2]);
println!(
"Finished inference in {:?} sec.",
start.elapsed().as_secs_f64()
);
vec![probability]
};

// ------- Client side ------- //
let mut total = 0;
for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() {
let ptxt_probability: u64 = client_key.decrypt(probability);

let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize;
println!("[{}] predicted {:?}, target {:?}", num, class, target);
if class == *target {
total += 1;
}
}
let accuracy = (total as f32 / num_samples as f32) * 100.0;
println!("Accuracy {accuracy}%");
}

0 comments on commit 3ac55d7

Please sign in to comment.