Skip to content

Commit

Permalink
Update encrypted LR with DWT
Browse files Browse the repository at this point in the history
  • Loading branch information
cgouert committed Apr 8, 2024
1 parent 5c12563 commit 366544f
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/encrypted_lr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,6 @@ fn main() {
total += 1;
}
}
let accuracy = (total as f32 / encrypted_dataset.len() as f32) * 100.0;
let accuracy = (total as f32 / num_samples as f32) * 100.0;
println!("Accuracy {accuracy}%");
}
133 changes: 87 additions & 46 deletions src/encrypted_lr_dwt.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::time::Instant;

use clap::{App, Arg};
use fhe_lut::common::*;
use rayon::prelude::*;
// use serde::{Deserialize, Serialize};
Expand All @@ -23,13 +24,33 @@ fn eval_exp(x: u64, exp_map: &Vec<u64>) -> u64 {
}

fn main() {
let matches = App::new("Ripple")
.about("Vanilla Encrypted Logistic Regression")
.arg(
Arg::new("num-samples")
.long("num-samples")
.short('n')
.takes_value(true)
.value_name("INT")
.help("Number of samples")
.default_value("1")
.required(false),
)
.get_matches();

let num_samples = matches
.value_of("num-samples")
.unwrap_or("1")
.parse::<usize>()
.expect("Number of samples must be an integer");

// ------- Client side ------- //
let bit_width = 24u8;
let precision = 8;
let table_size = 12;
assert!(precision <= bit_width / 2);

let (lut_lsb, lut_msb) = haar(table_size, precision, bit_width);
let (lut_lsb, _lut_msb) = haar(table_size, precision, bit_width);

// Number of blocks per ciphertext
let nb_blocks = bit_width >> 2;
Expand Down Expand Up @@ -83,71 +104,91 @@ fn main() {

let lut_gen_start = Instant::now();
println!("Generating LUT.");
let mut dummy = server_key.create_trivial_radix(2_u64, (nb_blocks << 1).into());
let mut dummy: RadixCiphertext =
server_key.create_trivial_radix(2_u64, (nb_blocks << 1).into());
for _ in 0..weights_int.len() {
let dummy_2 = server_key.smart_scalar_mul(&mut dummy, 2_u64);
dummy = server_key.unchecked_add(&dummy_2, &dummy);
let dummy_2 = server_key.scalar_mul_parallelized(&dummy, 2_u64);
dummy = server_key.add_parallelized(&dummy_2, &dummy);
}
let dummy_blocks = &dummy.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let dummy_msb = RadixCiphertext::from_blocks(dummy_blocks.to_vec());
let dummy_msb = server_key.unchecked_scalar_add(&dummy_msb, 1);
let dummy_msb = server_key.scalar_add_parallelized(&dummy_msb, 1);
let exp_lut_lsb = wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_exp(x, &lut_lsb));
let exp_lut_msb = wopbs_key.generate_lut_radix(&dummy_msb, |x: u64| eval_exp(x, &lut_msb));
println!(
"LUT generation done in {:?} sec.",
lut_gen_start.elapsed().as_secs_f64()
);

let encrypted_dataset_short = encrypted_dataset.get_mut(0..8).unwrap();
let all_probabilities = encrypted_dataset_short
.par_iter_mut()
.enumerate()
.map(|(cnt, sample)| {
let start = Instant::now();
println!("Started inference #{:?}.", cnt);

let mut prediction = server_key.create_trivial_radix(bias_int, (nb_blocks << 1).into());
for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) {
let ct_prod = server_key.smart_scalar_mul(s, weight);
prediction = server_key.unchecked_add(&ct_prod, &prediction);
}
// Truncate
let prediction_blocks =
&prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec());
let prediction_msb = server_key.unchecked_scalar_add(&prediction_msb, 1);
// Keyswitch and Bootstrap
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb);
let mut lsb_blocks = wopbs_key
.keyswitch_to_pbs_params(&activation_lsb)
.into_blocks();
let activation_msb = wopbs_key.wopbs(&prediction, &exp_lut_msb);
let msb_blocks = wopbs_key
.keyswitch_to_pbs_params(&activation_msb)
.into_blocks();
lsb_blocks.extend(msb_blocks);
let probability = RadixCiphertext::from_blocks(lsb_blocks);

println!(
"Finished inference #{:?} in {:?} sec.",
cnt,
start.elapsed().as_secs_f64()
);
probability
})
.collect::<Vec<_>>();
// Inference
assert!(num_samples <= encrypted_dataset.len());
let all_probabilities = if num_samples > 1 {
encrypted_dataset
.par_iter_mut()
.enumerate()
.take(num_samples)
.map(|(cnt, sample)| {
let start = Instant::now();
println!("Started inference #{:?}.", cnt);

let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into());
for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) {
let ct_prod = server_key.scalar_mul_parallelized(s, weight);
prediction = server_key.add_parallelized(&ct_prod, &prediction);
}
// Truncate
let prediction_blocks =
&prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec());
let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1);
// Keyswitch and Bootstrap
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb);
let probability = wopbs_key.keyswitch_to_pbs_params(&activation_lsb);
println!(
"Finished inference #{:?} in {:?} sec.",
cnt,
start.elapsed().as_secs_f64()
);
probability
})
.collect::<Vec<_>>()
} else {
let start = Instant::now();
println!("Started inference.");

let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into());
for (s, &weight) in encrypted_dataset[0].iter_mut().zip(weights_int.iter()) {
let ct_prod = server_key.scalar_mul_parallelized(s, weight);
prediction = server_key.add_parallelized(&ct_prod, &prediction);
}
// Truncate
let prediction_blocks =
&prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec());
let prediction_msb = server_key.scalar_add_parallelized(&prediction_msb, 1);
// Keyswitch and Bootstrap
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let activation_lsb = wopbs_key.wopbs(&prediction, &exp_lut_lsb);
let probability = wopbs_key.keyswitch_to_pbs_params(&activation_lsb);

println!(
"Finished inference in {:?} sec.",
start.elapsed().as_secs_f64()
);
vec![probability]
};

// ------- Client side ------- //
let mut total = 0;
for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() {
let ptxt_probability: u64 = client_key.decrypt(probability);

let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize;
println!("[{}] predicted {:?}, target {:?}", num, class, target);
if class == *target {
total += 1;
}
}
let accuracy = (total as f32 / encrypted_dataset_short.len() as f32) * 100.0;
let accuracy = (total as f32 / num_samples as f32) * 100.0;
println!("Accuracy {accuracy}%");
}
2 changes: 1 addition & 1 deletion src/sigmoid_encrypted_dwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ fn main() {

// ------- Client side ------- //
for (prob, data) in all_probabilities.iter().zip(dataset.iter()) {
let res: u64 = client_key.decrypt(&prob);
let res: u64 = client_key.decrypt(prob);
let exp_lsb = lut_lsb_plain[(data >> (bit_width - table_size)) as usize];
let exp_msb = lut_msb_plain[(data >> (bit_width - table_size)) as usize];
let exp = (exp_msb << (bit_width / 2)) + exp_lsb;
Expand Down

0 comments on commit 366544f

Please sign in to comment.