Skip to content

Commit

Permalink
Adjust parameterization for Haar-DWT correlation coefficient
Browse files Browse the repository at this point in the history
  • Loading branch information
cgouert committed Jun 3, 2024
1 parent 7bd7bf8 commit d1eae79
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions src/correlation_haar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ fn main() {

// ------- Client side ------- //
let bit_width = 16;
let precision = 6;
let precision = 0;

// Number of blocks per ciphertext
let nb_blocks = bit_width / 2;
Expand Down Expand Up @@ -67,7 +67,7 @@ fn main() {
.iter()
.map(|&exp| ((exp as f64) - experience_mean).powi(2))
.sum();
let experience_stddev = experience_variance.sqrt();
let experience_stddev = quantize(experience_variance.sqrt(), precision, bit_width as u8);

// Offline: LUT genaration is offline cost.
let lut_gen_start = Instant::now();
Expand All @@ -86,11 +86,7 @@ fn main() {
bit_width as u8,
bit_width as u8,
&|x: f64| {
if x.abs() < 0.05 {
1.0 // avoid division with zero error.
} else {
scale as f64 / (x.sqrt() * experience_stddev)
}
scale as f64 / (x.sqrt() * experience_stddev as f64)
},
);
let haar_lsb_lut_sqrt = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_lsb));
Expand All @@ -100,7 +96,7 @@ fn main() {
precision,
bit_width as u8,
bit_width as u8,
&|x: f64| x / dataset_size,
&|x: f64| x / dataset_size
);
let haar_lsb_lut_div = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_lsb));
let haar_msb_lut_div = wopbs_key.generate_lut_radix(&dummy, |x: u64| eval_lut(x, &haar_msb));
Expand All @@ -126,7 +122,6 @@ fn main() {
&haar_lsb_lut_div,
&haar_msb_lut_div,
);

// Cov = Sum_i^n (salary_i - mean(salary))(experience_i - mean(experience))
let covariance = encrypted_salaries
.iter()
Expand All @@ -151,7 +146,6 @@ fn main() {
server_key.create_trivial_radix(0_u64, nb_blocks),
|acc: RadixCiphertext, diff| server_key.add_parallelized(&acc, &diff),
);

// sigma_salary (or stddev) = sqrt(var_salary)
// println!("salaries_variance_enc degree: {:?}", salaries_variance_enc.blocks()[0].degree);
let salaries_stddev_enc = ct_lut_eval_haar_no_gen(
Expand All @@ -172,6 +166,5 @@ fn main() {
// ------- Client side ------- //
let correlation: u64 = client_key.decrypt(&correlation_enc);
let correlation_final: f64 = unquantize(correlation, precision, bit_width as u8);

println!("Correlation: {}", correlation_final / (scale as f64));
}

0 comments on commit d1eae79

Please sign in to comment.