Skip to content

Instantly share code, notes, and snippets.

@ladnir
Created February 28, 2020 18:00
Show Gist options
  • Save ladnir/b58941dc917615dfb3dc22951588e629 to your computer and use it in GitHub Desktop.
Save ladnir/b58941dc917615dfb3dc22951588e629 to your computer and use it in GitHub Desktop.
void OosNcoOtReceiver::computeProof()
{
if (eq(mChallengeSeed, ZeroBlock))
throw RTE_LOC;
// This AES will work as a PRNG, using AES-NI in counter mode.
AES aes(mChallengeSeed);
// the index of the AES counter.
u64 aesIdx(0);
// the index of the row that we are doing.
u64 k = 0;
// This will be used as a fast way to multiply the random challenge bits
// by the rows. zeroAndAllOneBlocks[0] will always be 00000.....00000,
// and zeroAndAllOneBlocks[1] will hold 111111.....111111.
// Multiplication is then just and array index and an & operation.
// i.e. x * block <==> block & zeroAndAllOneBlocks[x]
// This is so much faster than if(x) sum[l] = sum[l] ^ block
std::array<block, 2> zeroAndAllOneBlocks{ ZeroBlock, AllOneBlock };
u64 codeSize = mT0.stride();
// This will make the us send all of our input words
// and the complete T0 matrix. For DEBUG only
#ifdef OOS_CHECK_DEBUG
chl.send(mT0.data(), mT0.bounds()[0] * mT0.stride() * sizeof(block));
chl.send(mW.data(), mW.bounds()[0] * mW.stride() * sizeof(block));
std::vector<std::array<block, 2>> baseOTs(mGens.size());
std::vector<u64> mBlockIdxs(mGens.size());
for (u64 i = 0; i < baseOTs.size(); ++i)
{
baseOTs[i][0] = mGens[i][0].getSeed();
baseOTs[i][1] = mGens[i][1].getSeed();
mBlockIdxs[i] = mGens[i][0].mBlockIdx;
}
chl.send(baseOTs);
chl.send(mBlockIdxs);
#endif
// this will hold out random x^(l)_i values that we compute from the seed.
std::vector<block> challengeBuff(mStatSecParam);
// since we don't want to do bit shifting, this larger array
// will be used to hold each bit of challengeBuff as a whole
// byte. See below for how we do this efficiently.
std::vector<block> expandedBuff(mStatSecParam * 8);
auto byteView = span<u8>((u8*)expandedBuff.data(), expandedBuff.size() * sizeof(block));
// This will be used to compute expandedBuff
block mask = _mm_set_epi8(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
// get raw pointer to this data. faster than normal indexing.
auto mT0Iter = mT0.begin();
auto mWIter = mW.begin();
// compute the index that we should stop at. We process 128 rows at a time.
u64 lStop = (mCorrectionIdx - mStatSecParam + 127) / 128;
for (u64 l = 0; l < lStop; ++l)
{
// generate statSecParam * 128 bits using AES-NI in counter mode.
aes.ecbEncCounterMode(aesIdx, mStatSecParam, challengeBuff.data());
aesIdx += mStatSecParam;
// now expand each of these bits into its own byte. This is done with the
// right shift instruction _mm_srai_epi16. and then we mask to get only
// the bottom bit. Doing the 8 times gets us each bit in its own byte.
for (u64 i = 0; i < mStatSecParam; ++i)
{
expandedBuff[i * 8 + 0] = mask & _mm_srai_epi16(challengeBuff[i], 0);
expandedBuff[i * 8 + 1] = mask & _mm_srai_epi16(challengeBuff[i], 1);
expandedBuff[i * 8 + 2] = mask & _mm_srai_epi16(challengeBuff[i], 2);
expandedBuff[i * 8 + 3] = mask & _mm_srai_epi16(challengeBuff[i], 3);
expandedBuff[i * 8 + 4] = mask & _mm_srai_epi16(challengeBuff[i], 4);
expandedBuff[i * 8 + 5] = mask & _mm_srai_epi16(challengeBuff[i], 5);
expandedBuff[i * 8 + 6] = mask & _mm_srai_epi16(challengeBuff[i], 6);
expandedBuff[i * 8 + 7] = mask & _mm_srai_epi16(challengeBuff[i], 7);
}
// compute when we should stop of this set.
u64 stopIdx = std::min<u64>(mCorrectionIdx - mStatSecParam - k, u64(128));
k += 128;
// get an integrator to the challenge bit
auto xIter = byteView.begin();
if (codeSize == 4)
{
// vvvvvvvvvvvv OPTIMIZED for codeword size 4 vvvvvvvvvvvv
for (u64 i = 0; i < stopIdx; ++i, mT0Iter += 4)
{
// get the index of the first summation.
auto tSumIter = mTBuff.begin();
// For this row, iterate through all statSecParam challenge
// bits and add the row in if they are set to 1. We process
// two rows at a time.
for (u64 j = 0; j < mStatSecParam / 2; ++j, tSumIter += 8)
{
// get the challenge bits.
u8 x0 = *xIter++;
u8 x1 = *xIter++;
// dereference the challenge bits into blocks
// of either 000....0000 or 11111....111111
block mask0 = zeroAndAllOneBlocks[x0];
block mask1 = zeroAndAllOneBlocks[x1];
// now add the i'th row of T0 if the bit is 1.
// Otherwise this is a no op. Equiv. to an if(x).
auto t0x0 = *(mT0Iter + 0) & mask0;
auto t0x1 = *(mT0Iter + 1) & mask0;
auto t0x2 = *(mT0Iter + 2) & mask0;
auto t0x3 = *(mT0Iter + 3) & mask0;
auto t0x4 = *(mT0Iter + 0) & mask1;
auto t0x5 = *(mT0Iter + 1) & mask1;
auto t0x6 = *(mT0Iter + 2) & mask1;
auto t0x7 = *(mT0Iter + 3) & mask1;
// add them into the running totals.
tSumIter[0] = tSumIter[0] ^ t0x0;
tSumIter[1] = tSumIter[1] ^ t0x1;
tSumIter[2] = tSumIter[2] ^ t0x2;
tSumIter[3] = tSumIter[3] ^ t0x3;
tSumIter[4] = tSumIter[4] ^ t0x4;
tSumIter[5] = tSumIter[5] ^ t0x5;
tSumIter[6] = tSumIter[6] ^ t0x6;
tSumIter[7] = tSumIter[7] ^ t0x7;
}
}
xIter = byteView.begin();
for (u64 i = 0; i < stopIdx; ++i, ++mWIter)
{
// now do the same but for the input words.
auto wSumIter = mWBuff.begin();
for (u64 j = 0; j < mStatSecParam / 8; ++j, wSumIter += 8)
{
// we processes 8 rows of words at a time. Do the
// same masking trick.
auto wx0 = (*mWIter & zeroAndAllOneBlocks[xIter[0]]);
auto wx1 = (*mWIter & zeroAndAllOneBlocks[xIter[1]]);
auto wx2 = (*mWIter & zeroAndAllOneBlocks[xIter[2]]);
auto wx3 = (*mWIter & zeroAndAllOneBlocks[xIter[3]]);
auto wx4 = (*mWIter & zeroAndAllOneBlocks[xIter[4]]);
auto wx5 = (*mWIter & zeroAndAllOneBlocks[xIter[5]]);
auto wx6 = (*mWIter & zeroAndAllOneBlocks[xIter[6]]);
auto wx7 = (*mWIter & zeroAndAllOneBlocks[xIter[7]]);
// add them into the running totals.
wSumIter[0] = wSumIter[0] ^ wx0;
wSumIter[1] = wSumIter[1] ^ wx1;
wSumIter[2] = wSumIter[2] ^ wx2;
wSumIter[3] = wSumIter[3] ^ wx3;
wSumIter[4] = wSumIter[4] ^ wx4;
wSumIter[5] = wSumIter[5] ^ wx5;
wSumIter[6] = wSumIter[6] ^ wx6;
wSumIter[7] = wSumIter[7] ^ wx7;
xIter += 8;
}
}
// ^^^^^^^^^^^^^ OPTIMIZED for codeword size 4 ^^^^^^^^^^^^^
}
else
{
// vvvvvvvvvvvv general codeword size vvvvvvvvvvvv
for (u64 i = 0; i < stopIdx; ++i, mT0Iter += codeSize)
{
auto tSumIter = mTBuff.begin();
// For this row, iterate through all statSecParam challenge
// bits and add the row in if they are set to 1. We process
// two rows at a time.
for (u64 j = 0; j < mStatSecParam; ++j, tSumIter += codeSize)
{
block mask0 = zeroAndAllOneBlocks[*xIter++];
for (u64 m = 0; m < codeSize; ++m)
{
// now add the i'th row of T0 if the bit is 1.
// Otherwise this is a no op. Equiv. to an if(x).
tSumIter[m] = tSumIter[m] ^ (*(mT0Iter + m) & mask0);
}
}
}
if (mW.stride() != 1)
throw std::runtime_error("generalize this code vvvvvv " LOCATION);
xIter = byteView.begin();
for (u64 i = 0; i < stopIdx; ++i, ++mWIter)
{
auto wSumIter = mWBuff.begin();
// now do the same but for the input words.
for (u64 j = 0; j < mStatSecParam / 8; ++j, wSumIter += 8)
{
// we processes 8 rows of words at a time. Do the
// same masking trick.
auto wx0 = (*mWIter & zeroAndAllOneBlocks[xIter[0]]);
auto wx1 = (*mWIter & zeroAndAllOneBlocks[xIter[1]]);
auto wx2 = (*mWIter & zeroAndAllOneBlocks[xIter[2]]);
auto wx3 = (*mWIter & zeroAndAllOneBlocks[xIter[3]]);
auto wx4 = (*mWIter & zeroAndAllOneBlocks[xIter[4]]);
auto wx5 = (*mWIter & zeroAndAllOneBlocks[xIter[5]]);
auto wx6 = (*mWIter & zeroAndAllOneBlocks[xIter[6]]);
auto wx7 = (*mWIter & zeroAndAllOneBlocks[xIter[7]]);
// add them into the running totals.
wSumIter[0] = wSumIter[0] ^ wx0;
wSumIter[1] = wSumIter[1] ^ wx1;
wSumIter[2] = wSumIter[2] ^ wx2;
wSumIter[3] = wSumIter[3] ^ wx3;
wSumIter[4] = wSumIter[4] ^ wx4;
wSumIter[5] = wSumIter[5] ^ wx5;
wSumIter[6] = wSumIter[6] ^ wx6;
wSumIter[7] = wSumIter[7] ^ wx7;
xIter += 8;
}
}
// ^^^^^^^^^^^^^ general codeword size ^^^^^^^^^^^^^
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment