Skip to content

Commit

Permalink
feat: add encrypted multi-euclidean
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Apr 30, 2024
1 parent 779136e commit 77ff965
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 30 deletions.
12 changes: 6 additions & 6 deletions data/euclidean.csv
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
client,server
1,2
3,4
5,6
7,8
9,10
client,server_1,server_2,server_3,server_4
1,2,3,4,5
3,4,5,6,7
5,6,8,9,1
7,8,9,1,2
9,10,1,2,3
20 changes: 12 additions & 8 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,23 @@ pub fn db2() -> (Vec<Vec<u64>>, Vec<u64>) {
(lut_lsb_vecs, lut_msb_vec)
}

pub fn read_csv_two_columns(filename: &str) -> (Vec<u32>, Vec<u32>) {
let mut x = Vec::new();
let mut y = Vec::new();

pub fn read_csv(filename: &str) -> Vec<Vec<u32>> {
let csv = File::open(filename).unwrap();
let mut reader = csv::Reader::from_reader(csv);

let num_columns = reader.headers().unwrap().len();
let mut data = vec![vec![]; num_columns];
for line in reader.deserialize() {
let res: Vec<u32> = line.expect("a CSV record");
x.push(res[0]);
y.push(res[1]);
let record: Vec<u32> = line.unwrap();
if record.len() != num_columns {
panic!("Number of columns in row does not match header");
}
for (i, &value) in record.iter().enumerate() {
data[i].push(value);
}
}
(x, y)

data
}

pub fn quantized_table(
Expand Down
7 changes: 3 additions & 4 deletions src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@ fn pearson_correlation(x: &[u32], y: &[u32]) -> f64 {
}

fn main() {
let (experience, salary) = common::read_csv_two_columns("data/correlation.csv");
if experience.len() != salary.len() {
panic!("The length of the two arrays must be equal");
}
let data = common::read_csv("data/correlation.csv");
let experience = &data[0];
let salary = &data[1];

let mut salary_sorted = salary.clone();
salary_sorted.sort();
Expand Down
9 changes: 3 additions & 6 deletions src/encrypted_correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@ use tfhe::{
};

fn main() {
let (experience, salaries) = common::read_csv_two_columns("data/correlation.csv");
assert_eq!(
experience.len(),
salaries.len(),
"The length of the two arrays must be equal"
);
let data = common::read_csv("data/correlation.csv");
let experience = &data[0];
let salaries = &data[1];
let dataset_size = salaries.len() as f64;

let mut salary_sorted = salaries.clone();
Expand Down
122 changes: 116 additions & 6 deletions src/euclidean.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
use std::time::Instant;

use num_integer::Roots;
use rayon::prelude::*;
use ripple::common;
use tfhe::{
integer::{gen_keys_radix, wopbs::*, RadixCiphertext},
shortint::parameters::{
parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
},
};

/// d(x, y) = sqrt( sum((xi - yi)^2) )
fn euclidean(x: &[u32], y: &[u32]) -> f32 {
Expand All @@ -10,11 +21,110 @@ fn euclidean(x: &[u32], y: &[u32]) -> f32 {
}

fn main() {
let (x, y) = common::read_csv_two_columns("data/euclidean.csv");
let data = common::read_csv("data/euclidean.csv");
let xs = &data[0];

// ------- Client side ------- //
let bit_width = 16;

// Number of blocks per ciphertext
let nb_blocks = bit_width / 2;
println!(
"Number of blocks for the radix decomposition: {:?}",
nb_blocks
);

let start = Instant::now();
// Generate radix keys
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, nb_blocks);
// Generate key for PBS (without padding)
let wopbs_key = WopbsKey::new_wopbs_key(
&client_key,
&server_key,
&WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
);
println!(
"Key generation done in {:?} sec.",
start.elapsed().as_secs_f64()
);

let start = Instant::now();
let xs_enc: Vec<_> = xs
.par_iter() // Use par_iter() for parallel iteration
.map(|&x| client_key.encrypt(x))
.collect();
println!(
"Encryption done in {:?} sec.",
start.elapsed().as_secs_f64()
);

// ------- Server side ------- //
// TODO: Move LUT gens up here

let num_iter = 3;
assert!(
num_iter <= data.len() - 1,
"Not enough columns in CSV for that many iterations"
);

let mut sum_dists = (1..num_iter + 1)
.into_par_iter()
.map(|i| {
let ys = &data[i];

let distance = euclidean(&xs, &ys);
println!("{}) Ptxt Euclidean distance: {}", i, distance);

// Compute the encrypted euclidean distance

let start = Instant::now();
println!("{}) Starting computing Squared Euclidean distance", i);

let mut euclid_squared_enc = xs_enc
.iter()
.zip(ys.iter())
.map(|(x_enc, &y)| {
let diff = server_key.scalar_sub_parallelized(x_enc, y);
server_key.mul_parallelized(&diff, &diff)
})
.fold(
server_key.create_trivial_radix(0_u64, nb_blocks),
|acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff),
);
println!(
"{}) Finished computing Squared Euclidean distance in {:?} sec.",
i,
start.elapsed().as_secs_f64()
);

println!("{}) Starting computing square root", i);
let sqrt_lut = wopbs_key.generate_lut_radix(&euclid_squared_enc, |x: u64| x.sqrt());
euclid_squared_enc =
wopbs_key.keyswitch_to_wopbs_params(&server_key, &euclid_squared_enc);
let mut distance_enc = wopbs_key.wopbs(&euclid_squared_enc, &sqrt_lut);
distance_enc = wopbs_key.keyswitch_to_pbs_params(&distance_enc);
println!(
"{}) Finished computing square root in {:?} sec.",
i,
start.elapsed().as_secs_f64()
);

distance_enc
})
.collect::<Vec<_>>()
.into_iter()
.fold(
server_key.create_trivial_radix(0_u64, nb_blocks),
|acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff),
);

let div_lut =
wopbs_key.generate_lut_radix(&sum_dists, |x: u64| x / (num_iter as u64));
sum_dists = wopbs_key.keyswitch_to_wopbs_params(&server_key, &sum_dists);
let mut dists_mean_enc = wopbs_key.wopbs(&sum_dists, &div_lut);
dists_mean_enc = wopbs_key.keyswitch_to_pbs_params(&dists_mean_enc);

if x.len() != y.len() {
panic!("The length of the two arrays must be equal");
}
let distance = euclidean(&x, &y);
println!("Euclidean distance: {}", distance);
// ------- Client side ------- //
let mean_distance: u64 = client_key.decrypt(&dists_mean_enc);
println!("Mean of {} Euclidean distances: {}", num_iter, mean_distance);
}

0 comments on commit 77ff965

Please sign in to comment.