Skip to content

Commit

Permalink
Add tanh to primitive operations
Browse files Browse the repository at this point in the history
  • Loading branch information
cgouert committed May 17, 2024
1 parent 0fa4cce commit a41400d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ clap = "3.0"
image = "0.23"
num-integer = "0.1.46"
statrs = "0.16.0"
libm = "0.2.8"
csv = "1.3"
debug_print = "1.0.0"
dwt = "0.5.2"
Expand Down
55 changes: 52 additions & 3 deletions src/primitive_ops.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::time::Instant;

use libm::tanh;
use ripple::common::*;
use statrs::function::erf::erf;
use tfhe::{
Expand Down Expand Up @@ -45,11 +46,11 @@ fn ct_lut_eval_quantized(
wopbs_key: &WopbsKey,
server_key: &ServerKey,
) -> (RadixCiphertext, f64) {
let quant_blocks = &ct.clone().into_blocks()[(nb_blocks >> 1)..nb_blocks];
let quant_blocks = &ct.clone().into_blocks()[0..(nb_blocks >> 1)];
let quantized_ct = RadixCiphertext::from_blocks(quant_blocks.to_vec());
let quantized_lut = wopbs_key.generate_lut_radix(&quantized_ct, |x: u64| {
let x_unquantized = unquantize(x, precision, bit_width as u8);
quantize(func(x_unquantized), precision, bit_width as u8)
let x_unquantized = unquantize(x, precision, (bit_width >> 1) as u8);
quantize(func(x_unquantized), precision, (bit_width >> 1) as u8)
});
let start = Instant::now();
let quant_blocks = &ct.into_blocks()[(nb_blocks >> 1)..nb_blocks];
Expand Down Expand Up @@ -489,6 +490,54 @@ fn main() {
unquantize(lut_erf_quant, precision, bit_width as u8),
);

// 7.1 tanh(x) using LUT
let (tanh_ct, lut_time) = ct_lut_eval(
x_ct.clone(),
precision,
bit_width,
&tanh,
&wopbs_key,
&server_key,
);
let lut_tanh: u64 = client_key.decrypt(&tanh_ct);
println!("Tanh (LUT) time: {:?}", lut_time);

// 6.2 tanh(x) using Haar DWT LUT
let (tanh_ct_haar, dwt_time) = ct_lut_eval_haar(
x_ct.clone(),
precision,
bit_width,
nb_blocks,
&tanh,
&wopbs_key,
&server_key,
);
let dwt_tanh: u64 = client_key.decrypt(&tanh_ct_haar);
println!("Tanh (Haar) time: {:?}", dwt_time);

// 6.3 tanh(x) using Quantized LUT
let (tanh_ct_quant, lut_time_quant) = ct_lut_eval_quantized(
x_ct.clone(),
precision,
bit_width,
nb_blocks,
&tanh,
&wopbs_key,
&server_key,
);
let lut_tanh_quant: u64 = client_key.decrypt(&tanh_ct_quant);
println!("Tanh (Quantized LUT) time: {:?}", lut_time_quant);

println!(
"--- LUT: {:?}, DWT LUT: {:?}, Quant LUT: {:?}, \n--- unq: LUT: {:?}, DWT LUT: {:?}, Quant LUT {:?}",
lut_tanh,
dwt_tanh,
lut_tanh_quant,
unquantize(lut_tanh, precision, bit_width as u8),
unquantize(dwt_tanh, precision, bit_width as u8),
unquantize(lut_tanh_quant, precision, bit_width as u8),
);

// let x_ct = client_key.encrypt(5_u64);
// let x_neg_ct = client_key.encrypt(2_u64.pow(bit_width as u32)-5_u64);

Expand Down

0 comments on commit a41400d

Please sign in to comment.