Created
February 28, 2020 18:00
-
-
Save ladnir/b58941dc917615dfb3dc22951588e629 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
void OosNcoOtReceiver::computeProof() | |
{ | |
if (eq(mChallengeSeed, ZeroBlock)) | |
throw RTE_LOC; | |
// This AES will work as a PRNG, using AES-NI in counter mode. | |
AES aes(mChallengeSeed); | |
// the index of the AES counter. | |
u64 aesIdx(0); | |
// the index of the row that we are doing. | |
u64 k = 0; | |
// This will be used as a fast way to multiply the random challenge bits | |
// by the rows. zeroAndAllOneBlocks[0] will always be 00000.....00000, | |
// and zeroAndAllOneBlocks[1] will hold 111111.....111111. | |
// Multiplication is then just and array index and an & operation. | |
// i.e. x * block <==> block & zeroAndAllOneBlocks[x] | |
// This is so much faster than if(x) sum[l] = sum[l] ^ block | |
std::array<block, 2> zeroAndAllOneBlocks{ ZeroBlock, AllOneBlock }; | |
u64 codeSize = mT0.stride(); | |
// This will make the us send all of our input words | |
// and the complete T0 matrix. For DEBUG only | |
#ifdef OOS_CHECK_DEBUG | |
chl.send(mT0.data(), mT0.bounds()[0] * mT0.stride() * sizeof(block)); | |
chl.send(mW.data(), mW.bounds()[0] * mW.stride() * sizeof(block)); | |
std::vector<std::array<block, 2>> baseOTs(mGens.size()); | |
std::vector<u64> mBlockIdxs(mGens.size()); | |
for (u64 i = 0; i < baseOTs.size(); ++i) | |
{ | |
baseOTs[i][0] = mGens[i][0].getSeed(); | |
baseOTs[i][1] = mGens[i][1].getSeed(); | |
mBlockIdxs[i] = mGens[i][0].mBlockIdx; | |
} | |
chl.send(baseOTs); | |
chl.send(mBlockIdxs); | |
#endif | |
// this will hold out random x^(l)_i values that we compute from the seed. | |
std::vector<block> challengeBuff(mStatSecParam); | |
// since we don't want to do bit shifting, this larger array | |
// will be used to hold each bit of challengeBuff as a whole | |
// byte. See below for how we do this efficiently. | |
std::vector<block> expandedBuff(mStatSecParam * 8); | |
auto byteView = span<u8>((u8*)expandedBuff.data(), expandedBuff.size() * sizeof(block)); | |
// This will be used to compute expandedBuff | |
block mask = _mm_set_epi8(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); | |
// get raw pointer to this data. faster than normal indexing. | |
auto mT0Iter = mT0.begin(); | |
auto mWIter = mW.begin(); | |
// compute the index that we should stop at. We process 128 rows at a time. | |
u64 lStop = (mCorrectionIdx - mStatSecParam + 127) / 128; | |
for (u64 l = 0; l < lStop; ++l) | |
{ | |
// generate statSecParam * 128 bits using AES-NI in counter mode. | |
aes.ecbEncCounterMode(aesIdx, mStatSecParam, challengeBuff.data()); | |
aesIdx += mStatSecParam; | |
// now expand each of these bits into its own byte. This is done with the | |
// right shift instruction _mm_srai_epi16. and then we mask to get only | |
// the bottom bit. Doing the 8 times gets us each bit in its own byte. | |
for (u64 i = 0; i < mStatSecParam; ++i) | |
{ | |
expandedBuff[i * 8 + 0] = mask & _mm_srai_epi16(challengeBuff[i], 0); | |
expandedBuff[i * 8 + 1] = mask & _mm_srai_epi16(challengeBuff[i], 1); | |
expandedBuff[i * 8 + 2] = mask & _mm_srai_epi16(challengeBuff[i], 2); | |
expandedBuff[i * 8 + 3] = mask & _mm_srai_epi16(challengeBuff[i], 3); | |
expandedBuff[i * 8 + 4] = mask & _mm_srai_epi16(challengeBuff[i], 4); | |
expandedBuff[i * 8 + 5] = mask & _mm_srai_epi16(challengeBuff[i], 5); | |
expandedBuff[i * 8 + 6] = mask & _mm_srai_epi16(challengeBuff[i], 6); | |
expandedBuff[i * 8 + 7] = mask & _mm_srai_epi16(challengeBuff[i], 7); | |
} | |
// compute when we should stop of this set. | |
u64 stopIdx = std::min<u64>(mCorrectionIdx - mStatSecParam - k, u64(128)); | |
k += 128; | |
// get an integrator to the challenge bit | |
auto xIter = byteView.begin(); | |
if (codeSize == 4) | |
{ | |
// vvvvvvvvvvvv OPTIMIZED for codeword size 4 vvvvvvvvvvvv | |
for (u64 i = 0; i < stopIdx; ++i, mT0Iter += 4) | |
{ | |
// get the index of the first summation. | |
auto tSumIter = mTBuff.begin(); | |
// For this row, iterate through all statSecParam challenge | |
// bits and add the row in if they are set to 1. We process | |
// two rows at a time. | |
for (u64 j = 0; j < mStatSecParam / 2; ++j, tSumIter += 8) | |
{ | |
// get the challenge bits. | |
u8 x0 = *xIter++; | |
u8 x1 = *xIter++; | |
// dereference the challenge bits into blocks | |
// of either 000....0000 or 11111....111111 | |
block mask0 = zeroAndAllOneBlocks[x0]; | |
block mask1 = zeroAndAllOneBlocks[x1]; | |
// now add the i'th row of T0 if the bit is 1. | |
// Otherwise this is a no op. Equiv. to an if(x). | |
auto t0x0 = *(mT0Iter + 0) & mask0; | |
auto t0x1 = *(mT0Iter + 1) & mask0; | |
auto t0x2 = *(mT0Iter + 2) & mask0; | |
auto t0x3 = *(mT0Iter + 3) & mask0; | |
auto t0x4 = *(mT0Iter + 0) & mask1; | |
auto t0x5 = *(mT0Iter + 1) & mask1; | |
auto t0x6 = *(mT0Iter + 2) & mask1; | |
auto t0x7 = *(mT0Iter + 3) & mask1; | |
// add them into the running totals. | |
tSumIter[0] = tSumIter[0] ^ t0x0; | |
tSumIter[1] = tSumIter[1] ^ t0x1; | |
tSumIter[2] = tSumIter[2] ^ t0x2; | |
tSumIter[3] = tSumIter[3] ^ t0x3; | |
tSumIter[4] = tSumIter[4] ^ t0x4; | |
tSumIter[5] = tSumIter[5] ^ t0x5; | |
tSumIter[6] = tSumIter[6] ^ t0x6; | |
tSumIter[7] = tSumIter[7] ^ t0x7; | |
} | |
} | |
xIter = byteView.begin(); | |
for (u64 i = 0; i < stopIdx; ++i, ++mWIter) | |
{ | |
// now do the same but for the input words. | |
auto wSumIter = mWBuff.begin(); | |
for (u64 j = 0; j < mStatSecParam / 8; ++j, wSumIter += 8) | |
{ | |
// we processes 8 rows of words at a time. Do the | |
// same masking trick. | |
auto wx0 = (*mWIter & zeroAndAllOneBlocks[xIter[0]]); | |
auto wx1 = (*mWIter & zeroAndAllOneBlocks[xIter[1]]); | |
auto wx2 = (*mWIter & zeroAndAllOneBlocks[xIter[2]]); | |
auto wx3 = (*mWIter & zeroAndAllOneBlocks[xIter[3]]); | |
auto wx4 = (*mWIter & zeroAndAllOneBlocks[xIter[4]]); | |
auto wx5 = (*mWIter & zeroAndAllOneBlocks[xIter[5]]); | |
auto wx6 = (*mWIter & zeroAndAllOneBlocks[xIter[6]]); | |
auto wx7 = (*mWIter & zeroAndAllOneBlocks[xIter[7]]); | |
// add them into the running totals. | |
wSumIter[0] = wSumIter[0] ^ wx0; | |
wSumIter[1] = wSumIter[1] ^ wx1; | |
wSumIter[2] = wSumIter[2] ^ wx2; | |
wSumIter[3] = wSumIter[3] ^ wx3; | |
wSumIter[4] = wSumIter[4] ^ wx4; | |
wSumIter[5] = wSumIter[5] ^ wx5; | |
wSumIter[6] = wSumIter[6] ^ wx6; | |
wSumIter[7] = wSumIter[7] ^ wx7; | |
xIter += 8; | |
} | |
} | |
// ^^^^^^^^^^^^^ OPTIMIZED for codeword size 4 ^^^^^^^^^^^^^ | |
} | |
else | |
{ | |
// vvvvvvvvvvvv general codeword size vvvvvvvvvvvv | |
for (u64 i = 0; i < stopIdx; ++i, mT0Iter += codeSize) | |
{ | |
auto tSumIter = mTBuff.begin(); | |
// For this row, iterate through all statSecParam challenge | |
// bits and add the row in if they are set to 1. We process | |
// two rows at a time. | |
for (u64 j = 0; j < mStatSecParam; ++j, tSumIter += codeSize) | |
{ | |
block mask0 = zeroAndAllOneBlocks[*xIter++]; | |
for (u64 m = 0; m < codeSize; ++m) | |
{ | |
// now add the i'th row of T0 if the bit is 1. | |
// Otherwise this is a no op. Equiv. to an if(x). | |
tSumIter[m] = tSumIter[m] ^ (*(mT0Iter + m) & mask0); | |
} | |
} | |
} | |
if (mW.stride() != 1) | |
throw std::runtime_error("generalize this code vvvvvv " LOCATION); | |
xIter = byteView.begin(); | |
for (u64 i = 0; i < stopIdx; ++i, ++mWIter) | |
{ | |
auto wSumIter = mWBuff.begin(); | |
// now do the same but for the input words. | |
for (u64 j = 0; j < mStatSecParam / 8; ++j, wSumIter += 8) | |
{ | |
// we processes 8 rows of words at a time. Do the | |
// same masking trick. | |
auto wx0 = (*mWIter & zeroAndAllOneBlocks[xIter[0]]); | |
auto wx1 = (*mWIter & zeroAndAllOneBlocks[xIter[1]]); | |
auto wx2 = (*mWIter & zeroAndAllOneBlocks[xIter[2]]); | |
auto wx3 = (*mWIter & zeroAndAllOneBlocks[xIter[3]]); | |
auto wx4 = (*mWIter & zeroAndAllOneBlocks[xIter[4]]); | |
auto wx5 = (*mWIter & zeroAndAllOneBlocks[xIter[5]]); | |
auto wx6 = (*mWIter & zeroAndAllOneBlocks[xIter[6]]); | |
auto wx7 = (*mWIter & zeroAndAllOneBlocks[xIter[7]]); | |
// add them into the running totals. | |
wSumIter[0] = wSumIter[0] ^ wx0; | |
wSumIter[1] = wSumIter[1] ^ wx1; | |
wSumIter[2] = wSumIter[2] ^ wx2; | |
wSumIter[3] = wSumIter[3] ^ wx3; | |
wSumIter[4] = wSumIter[4] ^ wx4; | |
wSumIter[5] = wSumIter[5] ^ wx5; | |
wSumIter[6] = wSumIter[6] ^ wx6; | |
wSumIter[7] = wSumIter[7] ^ wx7; | |
xIter += 8; | |
} | |
} | |
// ^^^^^^^^^^^^^ general codeword size ^^^^^^^^^^^^^ | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment