Last active
April 12, 2016 03:05
-
-
Save daniman/1f8d75791746f1f00d15bee8be55389c to your computer and use it in GitHub Desktop.
Code to conduct the side channel attack for 6.857's pset 3 problem 2b.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import requests | |
import gmpy2 | |
import sys | |
import math | |
import json | |
from math import sqrt | |
# Size of the RSA modulus N in bits.
n = 1024
# R = 2^(n/2) = 2^512. main() computes R^{-1} mod N and compute_gap multiplies
# every guess by it before sending; presumably this cancels a multiplication
# by R in the server's (Montgomery-style) decryption -- confirm.
R = 1 << (n // 2)
# Base URL of the RSA decryption/timing server.
SERVER_URL = "http://6857rsa.csail.mit.edu:8080"
# Team identifier sent with every request; selects whose key is attacked.
TEAM = "daniman_ivonchyk_vossm_rchipman"
# | |
# Dependency Notes: | |
# This file requires python3 and the GMP library. It also requires the pip
# package gmpy2.
#
# First look up instructions for installing python3 and GMP on your
# operating system. Both are available through apt-get on Linux and
# Homebrew (brew) on macOS.
# | |
# Next make sure pip is installed (using python3 not python): | |
# https://pip.pypa.io/en/stable/installing/ | |
# | |
# Finally install gmpy2 | |
# python3 -m pip install gmpy2 | |
# | |
# If you are using Windows, make sure you have Python 3.2 or Python 3.3
# installed, then run the appropriate installer from
# https://pypi.python.org/pypi/gmpy2 to install gmpy2.
# | |
# Feel free to post on piazza for assistance! | |
# | |
# | |
# RSA Server API | |
# POST /decrypt | |
# json request body | |
# team: String, comma-separated list of team member kerberos names (or | |
# a practice team name previously generated by the server) | |
# ciphertext: String, hex-encoded | |
# no_n: bool, optional; if true, the server omits the modulus from the
# response. Setting this saves some network bandwidth if you already
# know the modulus for this team string's key.
# json response body | |
# modulus: String, hex-encoded, present if no_n was not true | |
# time: integer, units of time the decryption took (use this, not the | |
# real time the response takes to arrive) | |
# | |
# POST /guess | |
# json request body | |
# team: String | |
# q: String, hex-encoded, the smaller of (p, q) | |
# json response body | |
# correct: bool, whether the guess is correct | |
# | |
# GET /gen_practice | |
# no request body | |
# json response body | |
# team: String, a random team string generated by the server | |
# p: String, hex-encoded, the larger of the two secret primes | |
# q: String, hex-encoded, the smaller of the two secret primes | |
# | |
def main(skip_to=49100):
    """Recover the smaller RSA prime q of TEAM's key via the timing side channel.

    Steps:
      1. Fetch the public modulus N using a dummy all-zero ciphertext.
      2. Recover the top (512 - 16) bits of q one bit at a time from the
         decryption-time gap reported by compute_gap.
      3. Brute-force the remaining 16 low bits. Every candidate that passes
         the primality check is submitted; submit_guess exits the process
         when a guess is correct.

    Args:
        skip_to: offset into the 2**16 low-bit candidates at which the
            brute-force search starts. The default reproduces the author's
            shortcut (their key was found at offset 49155). Pass 0 to search
            all candidates.
    """
    N, response = _fetch_modulus()
    # compute R^{-1} mod N
    Rinv = gmpy2.invert(R, N)
    print(response)
    g = _recover_high_bits(Rinv, N)
    print("##########################################################")
    print(g)
    print("##########################################################")
    _brute_force_low_bits(g, skip_to)


def _fetch_modulus():
    """POST a dummy all-zero ciphertext; return (N, parsed JSON response)."""
    initial_request = {"team": TEAM, "ciphertext": "00" * (n // 8)}
    r = requests.post(SERVER_URL + "/decrypt", data=json.dumps(initial_request))
    try:
        response = r.json()
        N = int(response["modulus"], 16)
    except (ValueError, KeyError, TypeError):
        # Not JSON, or no usable hex "modulus" field: show the reply and stop.
        print(r.text)
        sys.exit(1)
    return N, response


def _recover_high_bits(Rinv, N):
    """Recover the top 512 - 16 bits of q; the low 16 bits of the result are 0."""
    g = 0
    # Gaps classified so far as 0-bits / 1-bits. The seed values make the
    # first threshold (1400 + 0) / 2 = 700.
    cluster_0 = [1400]
    cluster_1 = [0]
    for i in range(512 - 16):
        gap = compute_gap(g, i, Rinv, 512, N)
        # Threshold: midpoint between the mean 0-gap and the mean 1-gap so far.
        middle = (float(sum(cluster_0)) / len(cluster_0)
                  + float(sum(cluster_1)) / len(cluster_1)) / 2
        if gap < middle:
            print('Q[' + str(i) + '] = 1; Gap: ' + str(gap))
            g += 2 ** (512 - i)
            cluster_1.append(gap)
        else:
            print('Q[' + str(i) + '] = 0; Gap: ' + str(gap))
            cluster_0.append(gap)
    # Bits were accumulated one position too high (2^512 is a 513-bit number),
    # so halve to realign. Floor division keeps g an exact int on Python 3;
    # true division would produce a lossy float.
    return g // 2


def _brute_force_low_bits(g, skip_to):
    """Try q = g + skip_to + i for i in [0, 2**16 - skip_to); submit each prime."""
    # Cheap prefilter: skip candidates with an obvious small factor before
    # running the (more expensive) primality test.
    small_primes = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53)
    g += skip_to
    for i in range(2 ** 16 - skip_to):
        q = g + i
        if any(q % p == 0 for p in small_primes):
            continue
        prime = is_prime(q)
        print('Try: ' + str(i) + '; ' + str(prime))
        if prime:
            submit_guess(q)
def compute_gap(g, i, Rinv, n, N, n_its=50):
    """Return the timing gap between guess ``g`` and ``g`` with bit ``n - i`` added.

    Each guess x in {g, g + 2**(n - i)} is multiplied by Rinv modulo N and
    sent for timed decryption. This is repeated for the ``n_its`` consecutive
    values x, x + 1, ..., x + n_its - 1. main() interprets a small gap as a
    1 bit and a large gap as a 0 bit.

    Args:
        g: current guess; its top ``i`` bits are assumed correct.
        i: index of the bit being tested, counted from the top.
        Rinv: R^{-1} mod N.
        n: bit position base; bit ``n - i`` is the one added to ``g``
            (main passes 512). Shadows the module-level key size locally.
        N: public RSA modulus.
        n_its: number of consecutive neighbours of each guess to time.

    Returns:
        float: |sum(times for g) - sum(times for g_hi)| / n_its, i.e. the
        absolute difference of the two mean decryption times.
    """
    g_hi = g + 2 ** (n - i)
    total_lo = 0
    total_hi = 0
    for j in range(n_its):
        # Interleave the two guesses so server-side drift affects both equally.
        total_lo += time_decrypt((g + j) * Rinv % N)
        total_hi += time_decrypt((g_hi + j) * Rinv % N)
    return abs(total_lo - total_hi) / float(n_its)
def time_decrypt(ctxt):
    """Send ``ctxt`` to the server for decryption and return the reported time.

    The ciphertext is hex-encoded and zero-padded to the key size ``n``. The
    request sets ``no_n`` so the server omits the modulus from its reply.

    Args:
        ctxt: ciphertext integer (gmpy2 mpz or int).

    Returns:
        The server's "time" field: simulated units of decryption time. Use
        this rather than the wall-clock time of the request.

    Exits the process (after printing the server's raw reply) if the
    response is not JSON or lacks "time". This matches main() and
    gen_practice_key(). Returning None instead would only make compute_gap
    fail later with an unrelated TypeError inside sum().
    """
    padded_ctxt = ctxt_to_padded_hex_string(ctxt, n)
    req = {"team": TEAM, "ciphertext": padded_ctxt, "no_n": True}
    r = requests.post(SERVER_URL + "/decrypt", data=json.dumps(req))
    try:
        return r.json()["time"]
    except (ValueError, KeyError, TypeError):
        print(r.text)
        sys.exit(1)
def ctxt_to_padded_hex_string(ctxt, n):
    """Return ``ctxt`` as lowercase hex, left-padded with zeros to ``n`` bits.

    Accepts plain Python ints as well as gmpy2 mpz values (``int()`` converts
    either). The result has at least ``n // 4`` hex digits; longer values are
    returned unpadded and never truncated. ``ctxt`` is expected to be
    non-negative.
    """
    return format(int(ctxt), "x").zfill(n // 4)
def gen_practice_key():
    """Request a random practice key from the server.

    Returns:
        dict with "team" (the practice team string to use in later requests)
        and "p", "q" (the two secret primes as ints; per the server API notes,
        p is the larger and q the smaller).

    Exits the process, after printing the server's raw reply, if the response
    is not the expected JSON.
    """
    r = requests.get(SERVER_URL + "/gen_practice")
    try:
        # Named ``body`` so the ``json`` module is not shadowed.
        body = r.json()
        return {"team": body["team"], "p": int(body["p"], 16), "q": int(body["q"], 16)}
    except (ValueError, KeyError, TypeError):
        print(r.text)
        sys.exit(1)
def submit_guess(q):
    """Submit ``q`` to the server's /guess endpoint and print the reply.

    Exits the process with a banner showing q when the server reports the
    guess as correct. Also exits (status 1, after printing the raw reply) if
    the response is not JSON.
    """
    # format(q, "x") gives bare lowercase hex: no "0x" prefix and no Python 2
    # "L" suffix. The old hex(q)[2:-1] dropped the last real digit whenever
    # there was no "L" (Python 3, gmpy2 mpz).
    q_hex = format(q, "x")
    print('submitting request: ' + q_hex)
    data = {"team": TEAM, "q": q_hex}
    r = requests.post(SERVER_URL + "/guess", data=json.dumps(data))
    try:
        result = r.json()
    except ValueError:
        print(r.text)
        sys.exit(1)
    print(result)
    if r.ok and result.get('correct') is True:
        sys.exit("----------------------------------------------------\nRecovered Q: " +
                 q_hex + "\n----------------------------------------------------")
def mrange(start, stop, step, limit=5000000):
    """Yield start, start + step, ... for values below ``min(stop, limit)``.

    This is a lazy range that works with arbitrarily large ints. The
    exclusive upper bound is capped at ``limit``, an arbitrarily chosen bound
    that limits computation when ``stop`` is astronomically large (e.g. the
    square root of a 512-bit number). ``stop`` is respected when it is
    smaller. Previously it was overwritten, so callers always iterated up to
    ``limit``.
    """
    stop = min(stop, limit)
    while start < stop:
        yield start
        start += step
def is_prime(num):
    """Return True if ``num`` is (with overwhelming probability) prime.

    Small factors are handled by trial division. A Miller-Rabin strong
    probable-prime test then runs with the first twelve primes as bases.
    That base set is deterministic for every num < 3.3 * 10**24. Above that
    bound, a composite passing all twelve bases is possible in principle but
    vanishingly unlikely for non-adversarial inputs such as the brute-force
    candidates in main().

    This replaces trial division capped at 5,000,000. That version reported
    composites with no factor below the cap as prime, and (through mrange
    ignoring its stop) reported small primes such as 7 as composite.
    """
    if num < 2:
        return False
    bases = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)
    for p in bases:
        if num % p == 0:
            return num == p
    # Here num is odd and > 37. Write num - 1 = d * 2**s with d odd.
    d, s = num - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for a in bases:
        x = pow(a, d, num)
        if x == 1 or x == num - 1:
            continue
        for _ in range(s - 1):
            x = pow(x, 2, num)
            if x == num - 1:
                break
        else:
            # a witnesses that num is composite.
            return False
    return True
# Run the attack only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment