Skip to content

Instantly share code, notes, and snippets.

@fjarri
Created January 26, 2019 11:06
Show Gist options
  • Save fjarri/aa2e785dbee6ff66801fef4dada97675 to your computer and use it in GitHub Desktop.
Save fjarri/aa2e785dbee6ff66801fef4dada97675 to your computer and use it in GitHub Desktop.
Multi-GPU
from threading import Thread
from queue import Queue
import random
import nufhe
class MyThread:
    """Run ``target(*args)`` in a background thread and return its result on ``join()``.

    ``threading.Thread`` discards the target's return value; this wrapper
    passes it back through ``return_queue``. If the target raises, the
    exception is passed back instead and re-raised by ``join()``.
    """

    def __init__(self, target, args=()):
        # Receives exactly one (return_value, exception) pair when the target finishes.
        self.return_queue = Queue()
        self.target = target
        self.thread = Thread(target=self._target_wrapper, args=args)

    def _target_wrapper(self, *args):
        # Always put something on the queue: if the target raised and nothing
        # were enqueued, join() would block forever on return_queue.get().
        try:
            res = self.target(*args)
        except BaseException as exc:
            self.return_queue.put((None, exc))
        else:
            self.return_queue.put((res, None))

    def start(self):
        """Start the underlying thread."""
        self.thread.start()

    def join(self):
        """Wait for the thread to finish and return the target's return value.

        Raises:
            Whatever exception the target raised, re-raised in the calling thread.
        """
        ret_val, exc = self.return_queue.get()
        self.thread.join()
        if exc is not None:
            raise exc
        return ret_val
def worker(platform_device_id, cloud_key_cpu, ciphertext1_cpu, ciphertext2_cpu):
    """Evaluate the NAND gate on two serialized ciphertexts using one CUDA device.

    Creates its own nufhe CUDA context on ``platform_device_id``, loads the
    serialized inputs into it, and returns the serialized result. The
    ``__main__`` section of this script uses it as a thread target, one
    thread per device.

    Args:
        platform_device_id: device to create the context on (in this script,
            an entry of ``nufhe.get_devices(api='CUDA')``).
        cloud_key_cpu: serialized cloud key (output of ``cloud_key.dumps()``).
        ciphertext1_cpu: serialized first operand (output of ``.dumps()``).
        ciphertext2_cpu: serialized second operand (output of ``.dumps()``).

    Returns:
        The serialized NAND result (output of ``result.dumps()``), to be
        loaded back with ``load_ciphertext`` in the caller's context.
    """
    ctx = nufhe.Context(api='CUDA', platform_device_id=platform_device_id)
    try:
        cloud_key = ctx.load_cloud_key(cloud_key_cpu)
        ciphertext1 = ctx.load_ciphertext(ciphertext1_cpu)
        ciphertext2 = ctx.load_ciphertext(ciphertext2_cpu)
        vm = ctx.make_virtual_machine(cloud_key)
        result = vm.gate_nand(ciphertext1, ciphertext2)
        # Serialize before the cache is cleared in the finally clause.
        return result.dumps()
    finally:
        # Clear the computation cache on every exit path, not only on success.
        # NOTE(review): presumably this releases computations built for this
        # thread's context; confirm against the nufhe docs.
        nufhe.computation_cache.clear_computation_cache()
if __name__ == '__main__':
    size = 32
    bits1 = [random.choice([False, True]) for _ in range(size)]
    bits2 = [random.choice([False, True]) for _ in range(size)]
    reference = [not (b1 and b2) for b1, b2 in zip(bits1, bits2)]

    devices = nufhe.get_devices(api='CUDA')
    if not devices:
        raise RuntimeError("No CUDA devices found")
    # Use every available device, but never more chunks than bits, so no chunk is empty.
    num_parts = min(len(devices), size)
    # Chunk boundaries; with two devices this is exactly [0, size//2, size].
    bounds = [size * i // num_parts for i in range(num_parts + 1)]

    ctx = nufhe.Context()
    secret_key, cloud_key = ctx.make_key_pair()
    ciphertext1 = ctx.encrypt(secret_key, bits1)
    ciphertext2 = ctx.encrypt(secret_key, bits2)

    # Serialize the key and every chunk in the main thread before any worker
    # starts, so each worker can load them into its own device context.
    ck = cloud_key.dumps()
    jobs = [
        (device, ck, ciphertext1[start:stop].dumps(), ciphertext2[start:stop].dumps())
        for device, start, stop in zip(devices, bounds[:-1], bounds[1:])
    ]

    threads = [MyThread(target=worker, args=job) for job in jobs]
    for t in threads:
        t.start()
    # Join all workers before using the main context again.
    results_cpu = [t.join() for t in threads]

    decrypted = []
    for result_cpu in results_cpu:
        result_part = ctx.load_ciphertext(result_cpu)
        decrypted.extend(ctx.decrypt(secret_key, result_part).tolist())

    assert decrypted == reference
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment