Skip to content

Instantly share code, notes, and snippets.

@pathaine
Last active September 3, 2019 21:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pathaine/24842f38a0862dc07a43e6c2605d74ca to your computer and use it in GitHub Desktop.
Save pathaine/24842f38a0862dc07a43e6c2605d74ca to your computer and use it in GitHub Desktop.
/**
 * Encrypted short-time power spectrum of a real signal, computed via a real DFT.
 *
 * Frames start every `shiftLength` ms (sampleRate * shiftLength samples) and span
 * `frameLength` ms (sampleRate * frameLength samples). For every frame and every
 * frequency index k in [0, N/2] this computes
 *
 *     P_k = Re_k^2 + Im_k^2
 *
 * Re_k is the masked inner product of the encrypted signal with the plaintext
 * returned by calculate_cosines(..., k, ...). Im_k is the same with
 * calculate_sines(..., k, ...). This assumes the second argument of both helpers
 * is the frequency index (TODO confirm against their definitions).
 *
 * Im_k is treated as zero for k == 0 and, for even N, for k == N/2. No sine
 * term is computed for those bins, so P_k = Re_k^2 there.
 *
 * @param N            Signal size (number of samples). Must be > 0.
 * @param sampleRate   Sample rate in kHz (samples per ms).
 * @param frameLength  Frame length in ms.
 * @param shiftLength  Hop between consecutive frame starts in ms. Must give a
 *                     positive number of samples.
 * @param cc           Crypto context used for encoding, encryption and evaluation.
 * @param publicKey    Key used to encrypt the slot mask; also forwarded to the
 *                     basis helpers.
 * @param encSignal    Encrypted signal; only `encryptedNumerators` is read.
 * @param sk           Not used by this function (kept for interface compatibility).
 * @return result[f][k] is the encrypted P_k of frame f. Every slot except slot 0
 *         is zeroed by a one-hot mask. There are ceil(N / (sampleRate*shiftLength))
 *         frames with N/2 + 1 bins each. An empty vector is returned when N <= 0
 *         or the shift is not a positive number of samples.
 */
vector<vector<Ciphertext<DCRTPoly>>> real_discrete_fourier_transform(int N, int sampleRate, int frameLength, int shiftLength, CryptoContext<DCRTPoly> cc, LPPublicKey<DCRTPoly> publicKey, CryptoFractions encSignal, LPPrivateKey<DCRTPoly> sk) {
    (void)sk;  // decryption is never needed here

    const int numSamplesInFrame = sampleRate * frameLength;
    const int numSamplesInShift = sampleRate * shiftLength;

    // Without these guards the frame bookkeeping below would divide by zero,
    // never terminate, or size the result with a negative count.
    if (N <= 0 || numSamplesInShift <= 0) {
        return {};
    }

    // Highest cosine (Re) and sine (Im) frequency indices.
    // Even N: boundN = N/2, boundM = N/2 - 1 (the Nyquist bin has no sine term).
    // Odd N:  boundN = boundM = (N-1)/2.
    const int boundN = N / 2;
    const int boundM = (N - 1) / 2;

    // One-hot mask {1, 0, ..., 0}. Multiplying by it keeps slot 0 of an inner
    // product result and zeroes every other slot.
    vector<int64_t> oneHot(frameLength > 0 ? frameLength : 1, 0);
    oneHot[0] = 1;
    Plaintext oneHotPT = cc->MakePackedPlaintext(oneHot);
    Ciphertext<DCRTPoly> oneHotCT = cc->Encrypt(publicKey, oneHotPT);

    const unsigned long batchSize = upper_power_of_two(numSamplesInFrame);

    // Masked encrypted inner product of the signal with one plaintext basis vector.
    auto project = [&](const Plaintext& basis) {
        Ciphertext<DCRTPoly> inner = cc->EvalInnerProduct(encSignal.encryptedNumerators, basis, batchSize);
        return cc->EvalMult(inner, oneHotCT);
    };

    // Frames start at 0, shift, 2*shift, ... strictly below N:
    // ceil(N / shift) of them, computed in integer arithmetic.
    const int numFrames = N / numSamplesInShift + (N % numSamplesInShift != 0 ? 1 : 0);
    vector<vector<Ciphertext<DCRTPoly>>> RDFT(numFrames);

    for (int frame = 0; frame < numFrames; ++frame) {
        const int start = frame * numSamplesInShift;
        // NOTE(review): for trailing frames `end` can exceed N; presumably the
        // basis helpers clamp or zero-pad — TODO confirm.
        const int end = start + numSamplesInFrame;

        RDFT[frame].reserve(boundN + 1);
        for (int k = 0; k <= boundN; ++k) {
            Plaintext cosines = calculate_cosines(boundN, k, N, start, end, cc, publicKey);
            Ciphertext<DCRTPoly> re = project(cosines);
            Ciphertext<DCRTPoly> power = cc->EvalMult(re, re);

            // The sine term exists only for 1 <= k <= boundM and must be paired
            // with the cosine term of the SAME frequency k.
            if (k >= 1 && k <= boundM) {
                Plaintext sines = calculate_sines(boundM, k, N, start, end, cc, publicKey);
                Ciphertext<DCRTPoly> im = project(sines);
                power = cc->EvalAdd(power, cc->EvalMult(im, im));
            }
            RDFT[frame].push_back(power);
        }
    }
    return RDFT;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment