Created November 9, 2021 13:30
-
-
Save thdecn/9046a9a160a7180a6c058c73e867c6ba to your computer and use it in GitHub Desktop.
4-bit adder with concrete-boolean
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate concrete_boolean; | |
use concrete_boolean::gen_keys; | |
use concrete_boolean::server_key::ServerKey; | |
use concrete_boolean::ciphertext::Ciphertext; | |
use rand::Rng; | |
fn half_adder(server_key: &ServerKey, ct_a: &Ciphertext, ct_b: &Ciphertext) | |
-> (Ciphertext, Ciphertext) | |
{ | |
// sum = a xor b | |
let ct_sum = server_key.xor(ct_a, ct_b); | |
// carry = a and b | |
let ct_carry = server_key.and(ct_a, ct_b); | |
// return sum and carry | |
(ct_sum, ct_carry) | |
} | |
fn full_adder(server_key: &ServerKey, ct_a: &Ciphertext, ct_b: &Ciphertext, ct_c: &Ciphertext) | |
-> (Ciphertext, Ciphertext) | |
{ | |
// a xor b | |
let a_xor_b = server_key.xor(ct_a, ct_b); | |
// a and b | |
let a_and_b = server_key.and(ct_a, ct_b); | |
// (a xor b) and c | |
let a_xor_b_and_c = server_key.and(&a_xor_b, ct_c); | |
// sum = (a xor b) xor c | |
let ct_sum = server_key.xor(&a_xor_b, ct_c); | |
// carry = (a and b) or ((a xor b) and c) | |
let ct_carry = server_key.or(&a_and_b, &a_xor_b_and_c); | |
// return sum and carry | |
(ct_sum, ct_carry) | |
} | |
pub fn adder_4bit(server_key: &ServerKey, | |
ct_a: &(Ciphertext, Ciphertext, Ciphertext, Ciphertext), | |
ct_b: &(Ciphertext, Ciphertext, Ciphertext, Ciphertext) ) | |
-> (Ciphertext, Ciphertext, Ciphertext, Ciphertext, Ciphertext) | |
{ | |
// First Half Adder | |
let (ct_sum_0, ct_c0) = half_adder(&server_key, &ct_a.0, &ct_b.0); | |
// First Full Adder | |
let (ct_sum_1, ct_c1) = full_adder(&server_key, &ct_a.1, &ct_b.1, &ct_c0); | |
// Second Full Adder | |
let (ct_sum_2, ct_c2) = full_adder(&server_key, &ct_a.2, &ct_b.2, &ct_c1); | |
// Third Full Adder | |
let (ct_sum_3, ct_carry) = full_adder(&server_key, &ct_a.3, &ct_b.3, &ct_c2); | |
// Return Tupple of Carry and 4-bit Sum | |
(ct_carry, ct_sum_3, ct_sum_2, ct_sum_1, ct_sum_0) | |
} | |
fn main() { | |
// [CLIENT SIDE] | |
// Instantiate Random Number Generator | |
let mut rng = rand::thread_rng(); | |
// Generate 4 random Boolean values for our a plaintext | |
let a_pt = rng.gen::<(bool, bool, bool, bool)>(); | |
println!("A: {:?}", a_pt); | |
// Generate 4 random Boolean values for our b plaintext | |
let b_pt = rng.gen::<(bool, bool, bool, bool)>(); | |
println!("B: {:?}", b_pt); | |
// Generate a set of client/server keys, using the default parameters | |
let (client_key, server_key) = gen_keys(); | |
// Use the client secret key to encrypt plaintext a to ciphertext a | |
let a3_ct = client_key.encrypt(a_pt.3); | |
let a2_ct = client_key.encrypt(a_pt.2); | |
let a1_ct = client_key.encrypt(a_pt.1); | |
let a0_ct = client_key.encrypt(a_pt.0); | |
let a_ct = (a3_ct, a2_ct, a1_ct, a0_ct); | |
// Use the client secret key to encrypt plaintext b to ciphertext b | |
let b3_ct = client_key.encrypt(b_pt.3); | |
let b2_ct = client_key.encrypt(b_pt.2); | |
let b1_ct = client_key.encrypt(b_pt.1); | |
let b0_ct = client_key.encrypt(b_pt.0); | |
let b_ct = (b3_ct, b2_ct, b1_ct, b0_ct); | |
// [SERVER SIDE] | |
// Use the server public key to add the a and b ciphertexts | |
let (carry_ct, sum3_ct, sum2_ct, sum1_ct, sum0_ct) = adder_4bit(&server_key, &a_ct, &b_ct); | |
// [CLIENT SIDE] | |
// Use the client secret key to decrypt the ciphertext of the sum | |
let sum3_pt = client_key.decrypt(&sum3_ct); | |
let sum2_pt = client_key.decrypt(&sum2_ct); | |
let sum1_pt = client_key.decrypt(&sum1_ct); | |
let sum0_pt = client_key.decrypt(&sum0_ct); | |
let sum_pt = (sum3_pt, sum2_pt, sum1_pt, sum0_pt); | |
println!("Sum: {:?}", sum_pt); | |
// Use the client secret key to decrypt the ciphertext of the carry | |
let carry_pt = client_key.decrypt(&carry_ct); | |
println!("Carry: {:?}", carry_pt); | |
// Convert Boolean tupples to integers and check result | |
// Most Significant Bit in position 0, Least Significant Bit in position 3 | |
let mut a = 0; | |
if a_pt.0 { a += 8 } | |
if a_pt.1 { a += 4 } | |
if a_pt.2 { a += 2 } | |
if a_pt.3 { a += 1 } | |
println!("A: {}", a); | |
let mut b = 0; | |
if b_pt.0 { b += 8 } | |
if b_pt.1 { b += 4 } | |
if b_pt.2 { b += 2 } | |
if b_pt.3 { b += 1 } | |
println!("B: {}", b); | |
let mut sum = 0; | |
if sum_pt.0 { sum += 8 } | |
if sum_pt.1 { sum += 4 } | |
if sum_pt.2 { sum += 2 } | |
if sum_pt.3 { sum += 1 } | |
println!("Sum: {}", sum); | |
assert_eq!(sum, (a+b) % 16); | |
assert_eq!(carry_ct, (a+b) >= 16); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment