From a41400de41b7fb0d4c7ba5bf1b451c737bef97af Mon Sep 17 00:00:00 2001 From: cgouert Date: Thu, 16 May 2024 23:25:56 -0400 Subject: [PATCH] Add tanh to primitive operations --- Cargo.toml | 1 + src/primitive_ops.rs | 55 +++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 98214b2..dc25740 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ clap = "3.0" image = "0.23" num-integer = "0.1.46" statrs = "0.16.0" +libm = "0.2.8" csv = "1.3" debug_print = "1.0.0" dwt = "0.5.2" diff --git a/src/primitive_ops.rs b/src/primitive_ops.rs index c7a4bf8..a303d09 100644 --- a/src/primitive_ops.rs +++ b/src/primitive_ops.rs @@ -1,5 +1,6 @@ use std::time::Instant; +use libm::tanh; use ripple::common::*; use statrs::function::erf::erf; use tfhe::{ @@ -45,11 +46,11 @@ fn ct_lut_eval_quantized( wopbs_key: &WopbsKey, server_key: &ServerKey, ) -> (RadixCiphertext, f64) { - let quant_blocks = &ct.clone().into_blocks()[(nb_blocks >> 1)..nb_blocks]; + let quant_blocks = &ct.clone().into_blocks()[0..(nb_blocks >> 1)]; let quantized_ct = RadixCiphertext::from_blocks(quant_blocks.to_vec()); let quantized_lut = wopbs_key.generate_lut_radix(&quantized_ct, |x: u64| { - let x_unquantized = unquantize(x, precision, bit_width as u8); - quantize(func(x_unquantized), precision, bit_width as u8) + let x_unquantized = unquantize(x, precision, (bit_width >> 1) as u8); + quantize(func(x_unquantized), precision, (bit_width >> 1) as u8) }); let start = Instant::now(); let quant_blocks = &ct.into_blocks()[(nb_blocks >> 1)..nb_blocks]; @@ -489,6 +490,54 @@ fn main() { unquantize(lut_erf_quant, precision, bit_width as u8), ); + // 7.1 tanh(x) using LUT + let (tanh_ct, lut_time) = ct_lut_eval( + x_ct.clone(), + precision, + bit_width, + &tanh, + &wopbs_key, + &server_key, + ); + let lut_tanh: u64 = client_key.decrypt(&tanh_ct); + println!("Tanh (LUT) time: {:?}", lut_time); + + // 6.2 tanh(x) using Haar DWT LUT + let (tanh_ct_haar, dwt_time) = ct_lut_eval_haar( + x_ct.clone(), + precision, + bit_width, + nb_blocks, + &tanh, + &wopbs_key, + &server_key, + ); + let dwt_tanh: u64 = client_key.decrypt(&tanh_ct_haar); + println!("Tanh (Haar) time: {:?}", dwt_time); + + // 6.3 tanh(x) using Quantized LUT + let (tanh_ct_quant, lut_time_quant) = ct_lut_eval_quantized( + x_ct.clone(), + precision, + bit_width, + nb_blocks, + &tanh, + &wopbs_key, + &server_key, + ); + let lut_tanh_quant: u64 = client_key.decrypt(&tanh_ct_quant); + println!("Tanh (Quantized LUT) time: {:?}", lut_time_quant); + + println!( + "--- LUT: {:?}, DWT LUT: {:?}, Quant LUT: {:?}, \n--- unq: LUT: {:?}, DWT LUT: {:?}, Quant LUT {:?}", + lut_tanh, + dwt_tanh, + lut_tanh_quant, + unquantize(lut_tanh, precision, bit_width as u8), + unquantize(dwt_tanh, precision, bit_width as u8), + unquantize(lut_tanh_quant, precision, bit_width as u8), + ); + // let x_ct = client_key.encrypt(5_u64); // let x_neg_ct = client_key.encrypt(2_u64.pow(bit_width as u32)-5_u64);