Skip to content

Commit

Permalink
feat: lut input precision
Browse files Browse the repository at this point in the history
  • Loading branch information
nilmemo committed Apr 24, 2024
1 parent 3ac55d7 commit 3cd0f5f
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 52 deletions.
12 changes: 6 additions & 6 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ pub fn means_and_stds(dataset: &[Vec<f64>], num_features: usize) -> (Vec<f64>, V
(mins, maxs)
}

pub fn haar(table_size: u8, precision: u8, bit_width: u8) -> (Vec<u64>, Vec<u64>) {
pub fn haar(table_size: u8, input_precision: u8, output_precision: u8, bit_width: u8) -> (Vec<u64>, Vec<u64>) {
let max = 1 << bit_width;
let mut data = Vec::new();
for x in 0..max {
let x = unquantize(x, precision, bit_width);
let x = unquantize(x, input_precision, bit_width);
let sig = 1f64 / (1f64 + (-x).exp());
data.push(sig);
}
Expand All @@ -137,7 +137,7 @@ pub fn haar(table_size: u8, precision: u8, bit_width: u8) -> (Vec<u64>, Vec<u64>
.get(0..coef_len)
.unwrap()
.iter()
.map(|x| quantize(scalar * x, precision, bit_width))
.map(|x| quantize(scalar * x, output_precision, bit_width))
.collect();
haar.rotate_right(1 << (table_size - 1));
let mask = (1 << (bit_width / 2)) - 1;
Expand Down Expand Up @@ -176,14 +176,14 @@ pub fn db2() -> (Vec<Vec<u64>>, Vec<u64>) {
(lut_lsb_vecs, lut_msb_vec)
}

pub fn quantized_table(table_size: u8, precision: u8, bit_width: u8) -> (Vec<u64>, Vec<u64>) {
pub fn quantized_table(table_size: u8, input_precision: u8, output_precision: u8, bit_width: u8) -> (Vec<u64>, Vec<u64>) {
let mut data = Vec::new();
let max = 1 << (table_size);
for x in 0..max {
let x = x << (bit_width - table_size);
let xq = unquantize(x, precision, bit_width);
let xq = unquantize(x, input_precision, bit_width);
let sig = 1f64 / (1f64 + (-xq).exp());
data.push(quantize(sig, precision, bit_width));
data.push(quantize(sig, output_precision, bit_width));
}
let mask = (1 << (bit_width / 2)) - 1;
let lsb = data.clone().iter().map(|x| x & mask).collect();
Expand Down
29 changes: 14 additions & 15 deletions src/dwt_lr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ pub fn quantize_dataset(dataset: &Vec<Vec<f64>>, precision: u8, bit_width: u8) -
fn main() {
let bit_width = 24u8;
let precision = 8;
let table_size = 16;
let table_size = 12;

println!("Starting Haar");
let (lut_lsb, _lut_msb) = haar(table_size, precision, bit_width);
let (lut_lsb, _lut_msb) = haar(table_size, precision * 2, table_size, bit_width);
println!("{:?}", lut_lsb);
// println!("{:?}", lut_msb);
println!("End Haar");
Expand All @@ -38,30 +38,29 @@ fn main() {
// Server computation
let mut prediction = bias_int;
for (&s, &w) in sample.iter().zip(weights_int.iter()) {
println!("s: {:?}", s);
println!("weight: {:?}", w);
// println!("s: {:?}", s);
// println!("weight: {:?}", w);
prediction = add(prediction, mul(w, s, bit_width), bit_width);
println!("MAC result: {:?}", prediction);
// println!("MAC result: {:?}", prediction);
}
println!("prediction {prediction}");
let probability1 = sigmoid(prediction, 2 * precision, precision, bit_width);
prediction = trunc(prediction, bit_width, precision);
// println!("prediction {prediction}");
let probability = sigmoid(prediction, 2 * precision, table_size, bit_width);
let prediction = prediction >> (bit_width - table_size);
let probability = lut_lsb[prediction as usize];
let lut_probability = lut_lsb[prediction as usize];

println!("{probability1} {probability}");
println!("diff {:?}", probability as i64 - probability1 as i64);
// println!("{probability1} {probability}");
println!("diff {:?}", probability as i64 - lut_probability as i64);

let class = (probability > quantize(0.5, precision, bit_width)) as usize;
let class = (lut_probability > quantize(0.5, table_size, bit_width)) as usize;

// Client computation
println!("predicted {class:?}, target {target:?}");
// println!("predicted {class:?}, target {target:?}");
if class == *target {
total += 1;
}
println!();
// println!();
}
let accuracy = (total as f64 / dataset.len() as f64) * 100.0;
println!("Accuracy {accuracy}%");
println!("precision: {precision}, bit_width: {bit_width}");
println!("table size: {table_size}, precision: {precision}, bit_width: {bit_width}");
}
13 changes: 8 additions & 5 deletions src/encrypted_lr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use tfhe::{
};

fn main() {
println!("Encrypted Logistic Regression");

let matches = App::new("Ripple")
.about("Vanilla Encrypted Logistic Regression")
.arg(
Expand All @@ -33,13 +35,13 @@ fn main() {
.expect("Number of samples must be an integer");

// ------- Client side ------- //
let bit_width = 24u8;
let bit_width = 24;
let precision = 8;
assert!(precision <= bit_width / 2);

// Number of blocks per ciphertext
let nb_blocks = bit_width / 2;
println!("Number of blocks: {:?}", nb_blocks);
println!("Number of blocks for the radix decomposition: {:?}", nb_blocks);

let start = Instant::now();
// Generate radix keys
Expand Down Expand Up @@ -104,7 +106,7 @@ fn main() {
.take(num_samples)
.map(|(cnt, sample)| {
let start = Instant::now();
println!("Started inference #{:?}.", cnt);
println!("Starting inference #{:?}.", cnt);

let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into());
for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) {
Expand All @@ -125,7 +127,7 @@ fn main() {
.collect::<Vec<_>>()
} else {
let start = Instant::now();
println!("Started inference.");
println!("Starting inference.");

let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into());
for (s, &weight) in encrypted_dataset[0].iter_mut().zip(weights_int.iter()) {
Expand All @@ -147,9 +149,10 @@ fn main() {
let mut total = 0;
for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() {
let ptxt_probability: u64 = client_key.decrypt(probability);
let pr = (ptxt_probability as f64) / ((1<<precision) as f64);

let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize;
println!("[{}] predicted {:?}, target {:?}", num, class, target);
println!("[{}] predicted {:?}, target {:?} (prediction probability {:?})", num, class, target, pr);
if class == *target {
total += 1;
}
Expand Down
17 changes: 10 additions & 7 deletions src/encrypted_lr_dwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ fn eval_exp(x: u64, exp_map: &Vec<u64>) -> u64 {
}

fn main() {
println!("Encrypted Logistic Regression using Discrete Wavelet Transform");

let matches = App::new("Ripple")
.about("Vanilla Encrypted Logistic Regression")
.arg(
Expand All @@ -45,16 +47,16 @@ fn main() {
.expect("Number of samples must be an integer");

// ------- Client side ------- //
let bit_width = 24u8;
let bit_width = 24;
let precision = 8;
let table_size = 12;
let table_size = bit_width / 2;
assert!(precision <= bit_width / 2);

let (lut_lsb, _lut_msb) = haar(table_size, precision, bit_width);
let (lut_lsb, _lut_msb) = haar(table_size, 2 * precision, precision, bit_width);

// Number of blocks per ciphertext
let nb_blocks = bit_width >> 2;
println!("Number of blocks: {:?}", nb_blocks);
println!("Number of blocks for the radix decomposition: {:?}", nb_blocks);

let start = Instant::now();
// Generate radix keys
Expand Down Expand Up @@ -128,7 +130,7 @@ fn main() {
.take(num_samples)
.map(|(cnt, sample)| {
let start = Instant::now();
println!("Started inference #{:?}.", cnt);
println!("Starting inference #{:?}.", cnt);

let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into());
for (s, &weight) in sample.iter_mut().zip(weights_int.iter()) {
Expand All @@ -154,7 +156,7 @@ fn main() {
.collect::<Vec<_>>()
} else {
let start = Instant::now();
println!("Started inference.");
println!("Starting inference.");

let mut prediction = server_key.create_trivial_radix(bias_int, nb_blocks.into());
for (s, &weight) in encrypted_dataset[0].iter_mut().zip(weights_int.iter()) {
Expand Down Expand Up @@ -182,9 +184,10 @@ fn main() {
let mut total = 0;
for (num, (target, probability)) in targets.iter().zip(all_probabilities.iter()).enumerate() {
let ptxt_probability: u64 = client_key.decrypt(probability);
let pr = (ptxt_probability as f64) / ((1<<precision) as f64);

let class = (ptxt_probability > quantize(0.5, precision, bit_width)) as usize;
println!("[{}] predicted {:?}, target {:?}", num, class, target);
println!("[{}] predicted {:?}, target {:?} (prediction probability {:?})", num, class, target, pr);
if class == *target {
total += 1;
}
Expand Down
4 changes: 4 additions & 0 deletions src/plain_lr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ fn main() {
println!("MAC result: {:?}", prediction);
}
let probability = sigmoid(prediction, 2 * precision, precision, bit_width);
println!("probability {probability}");
let class = (probability > quantize(0.5, precision, bit_width)) as usize;

// Client computation
Expand All @@ -43,6 +44,9 @@ fn main() {
total += 1;
}
println!();
if total == 8 {
break;
}
}

let accuracy = (total as f64 / dataset.len() as f64) * 100.0;
Expand Down
20 changes: 10 additions & 10 deletions src/sigmoid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ fn main() {
let precision = 12;
let table_size = 8;

println!("Starting Haar");
let (lut_lsb, lut_msb) = quantized_table(table_size, precision, bit_width);
let (lut_haar_lsb, lut_haar_msb) = haar(table_size, precision, bit_width);
println!("Haar");
println!("Generating Lookup Tables");
let (lut_lsb, lut_msb) = quantized_table(table_size, precision, precision, bit_width);
let (lut_haar_lsb, lut_haar_msb) = haar(table_size, precision, precision, bit_width);

let mut diff_quant = Vec::new();
let mut diff_haar = Vec::new();
// let dataset: Vec<u64> = vec![0, 72, 1050, 1790, 10234, 60122, 65001, 65535];
let max = (1 << bit_width) - 1;
println!("Evaluating Sigmoid");
for x in 0..max {
let s0 = sigmoid(x, precision, precision, bit_width);
let s1 = sigmoid(
let _s1 = sigmoid(
trunc(x, bit_width, bit_width - table_size),
precision - (bit_width - table_size),
precision,
Expand All @@ -32,16 +32,16 @@ fn main() {
let s3 = s3a + (s3b << (bit_width / 2));
let diff = s0 as i64 - s3 as i64;
diff_haar.push(diff * diff);
println!("{x} {s0} {s1} {s2} {s3}");
// println!("{x} {s0} {_s1} {s2} {s3}");
}
println!(
"quantized mse {:?} ulp {:?}",
(diff_quant.iter().sum::<i64>() as f64).sqrt(),
"=== Quantized ===\nMean Squared Error {:?} \nMaximum Absolute Error {:?}",
(diff_quant.iter().sum::<i64>() as f64) / diff_quant.len() as f64,
(*diff_quant.iter().max().unwrap() as f64).sqrt()
);
println!(
"haar mse {:?} ulp {:?}",
(diff_haar.iter().sum::<i64>() as f64).sqrt(),
"=== Haar DWT ====\nMean Squared Error {:?} \nMaximum Absolute Error {:?}",
(diff_haar.iter().sum::<i64>() as f64) / diff_haar.len() as f64,
(*diff_haar.iter().max().unwrap() as f64).sqrt()
);
}
24 changes: 15 additions & 9 deletions src/sigmoid_encrypted_dwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ fn main() {
let precision = 12;
let table_size = 8;

let (lut_lsb_plain, lut_msb_plain) = haar(table_size, precision, bit_width);
let (lut_lsb_plain, lut_msb_plain) = haar(table_size, precision, precision, bit_width);

// Number of blocks per ciphertext
let nb_blocks = bit_width >> 2;
println!("Number of blocks: {:?}", nb_blocks);
let pbs_blocks = bit_width >> 2;
println!("Number of blocks: {:?}", pbs_blocks);

let start = Instant::now();
// Generate radix keys
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks.into());
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, pbs_blocks.into());

// Generate key for PBS (without padding)
let wopbs_key = WopbsKey::new_wopbs_key(
Expand All @@ -45,16 +45,22 @@ fn main() {

let dataset: Vec<u64> = vec![0, 72, 1050, 1790, 10234, 60122, 65001, 65535];
// Expected [2079, 2079, 2333, 2458, 3776, 847, 1888, 2015]
// Recieved [2143, 2143, 2396, 2519, 3794, 890, 1951, 2079]
// Received [2143, 2143, 2396, 2519, 3794, 890, 1951, 2079]

// let mut dataset = Vec::new();
// let max = 1 << 10;
// for i in 0..max {
// dataset.push(i * (1 << 6));
// }

let start = Instant::now();
let mut encrypted_dataset: Vec<_> = dataset
.par_iter() // Use par_iter() for parallel iteration
.map(|&sample| {
let mut lsb = client_key
.encrypt(sample & (1 << ((nb_blocks << 1) - 1)))
.encrypt(sample & (1 << ((pbs_blocks << 1) - 1)))
.into_blocks(); // Get LSBs
let msb = client_key.encrypt(sample >> (nb_blocks << 1)).into_blocks(); // Get MSBs
let msb = client_key.encrypt(sample >> (pbs_blocks << 1)).into_blocks(); // Get MSBs
lsb.extend(msb);
RadixCiphertext::from_blocks(lsb)
})
Expand All @@ -75,9 +81,9 @@ fn main() {
// Truncate
let mut prediction = sample.clone();
let prediction_blocks =
&prediction.into_blocks()[(nb_blocks as usize)..((nb_blocks << 1) as usize)];
&prediction.into_blocks()[(pbs_blocks as usize)..((pbs_blocks << 1) as usize)];
let prediction_msb = RadixCiphertext::from_blocks(prediction_blocks.to_vec());
let prediction_msb = server_key.unchecked_scalar_add(&prediction_msb, 1);
// let prediction_msb = server_key.unchecked_scalar_add(&prediction_msb, 1);
// Keyswitch and Bootstrap
prediction = wopbs_key.keyswitch_to_wopbs_params(&server_key, &prediction_msb);
let lut_lsb =
Expand Down

0 comments on commit 3cd0f5f

Please sign in to comment.