Skip to content

Instantly share code, notes, and snippets.

@xu-cheng
Last active May 31, 2021 20:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xu-cheng/b2cea7111e7efc818497474970089fef to your computer and use it in GitHub Desktop.
Save xu-cheng/b2cea7111e7efc818497474970089fef to your computer and use it in GitHub Desktop.
// Ref:
// https://github.com/mobilecoinofficial/mc-oblivious/blob/master/aligned-cmov/src/cmov_impl_asm.rs
#include <cassert>
#include <cstddef>
#include <cstdint>
// Conditionally copy a 64-bit value: afterwards *dst holds *src when
// `condition` is true and its previous value otherwise.
//
// The choice is made with a CMOV instruction rather than a branch on
// `condition` (x86 inline assembly, AT&T operand order).
inline __attribute__((always_inline)) void cmov_u64(const bool condition,
                                                    const uint64_t* src,
                                                    uint64_t* dst)
{
    uint64_t selected = *dst;
    asm volatile(
        // ZF is set exactly when condition == 0 ...
        "test %[cond], %[cond]\n\t"
        // ... so the source value is taken only when the condition is non-zero.
        "cmovnz %[from], %[acc]\n\t"
        : [acc] "+&r"(selected)
        : [cond] "r"(condition), [from] "rm"(*src)
        : "cc");
    *dst = selected;
}
// Conditionally copy a 32-bit value: afterwards *dst holds *src when
// `condition` is true and its previous value otherwise.
//
// The choice is made with a CMOV instruction rather than a branch on
// `condition` (x86 inline assembly, AT&T operand order).
inline __attribute__((always_inline)) void cmov_u32(const bool condition,
                                                    const uint32_t* src,
                                                    uint32_t* dst)
{
    uint32_t selected = *dst;
    asm volatile(
        // ZF is set exactly when condition == 0 ...
        "test %[cond], %[cond]\n\t"
        // ... so the source value is taken only when the condition is non-zero.
        "cmovnz %[from], %[acc]\n\t"
        : [acc] "+&r"(selected)
        : [cond] "r"(condition), [from] "rm"(*src)
        : "cc");
    *dst = selected;
}
// Conditionally copy a 16-bit value: afterwards *dst holds *src when
// `condition` is true and its previous value otherwise.
//
// The choice is made with a CMOV instruction rather than a branch on
// `condition` (x86 inline assembly, AT&T operand order).
inline __attribute__((always_inline)) void cmov_u16(const bool condition,
                                                    const uint16_t* src,
                                                    uint16_t* dst)
{
    uint16_t selected = *dst;
    asm volatile(
        // ZF is set exactly when condition == 0 ...
        "test %[cond], %[cond]\n\t"
        // ... so the source value is taken only when the condition is non-zero.
        "cmovnz %[from], %[acc]\n\t"
        : [acc] "+&r"(selected)
        : [cond] "r"(condition), [from] "rm"(*src)
        : "cc");
    *dst = selected;
}
// Conditionally copy a single byte: afterwards *dst holds *src when
// `condition` is true and its previous value otherwise.
//
// Both bytes are widened to 64 bits and the selection is delegated to
// cmov_u64 (presumably because x86 CMOV has no 8-bit operand form); only the
// low byte of the result is written back.
inline __attribute__((always_inline)) void cmov_u8(const bool condition,
                                                   const uint8_t* src,
                                                   uint8_t* dst)
{
    uint64_t wide_src = *src;
    uint64_t wide_dst = *dst;
    cmov_u64(condition, &wide_src, &wide_dst);
    *dst = static_cast<uint8_t>(wide_dst);
}
// Returns true when the address held by `pointer` is a multiple of
// `byte_count`. `byte_count` must be non-zero (it is used as a divisor).
inline __attribute__((always_inline)) bool is_aligned(const void* pointer,
                                                      size_t byte_count)
{
    const uintptr_t address = reinterpret_cast<uintptr_t>(pointer);
    return address % byte_count == 0;
}
// Conditionally copy a byte array: when `condition` is true the first `len`
// bytes at `src` are copied to `dst`; otherwise `dst` keeps its contents.
//
// The decision is never turned into a branch on `condition`: every path
// selects the value with CMOV (or an AVX2 masked store). Branches depend
// only on `len` and on pointer alignment.
//
// Parameters:
//   condition - whether the copy should take effect
//   src       - source bytes; must be 8-byte aligned (asserted)
//   dst       - destination bytes; must be 8-byte aligned (asserted)
//   len       - number of bytes to process
//
// The work is split into progressively smaller units: 64-byte AVX2 blocks
// (only when compiled with AVX2 and both pointers are 64-byte aligned),
// then 8-byte words, then at most one 4-byte, one 2-byte and one 1-byte tail.
inline __attribute__((always_inline)) void cmov_bytes(const bool condition,
                                                      const uint8_t* src,
                                                      uint8_t* dst, size_t len)
{
#ifdef __AVX2__
    // Bulk path: 64 bytes (two 256-bit masked stores) per iteration.
    if (len >= 64 && is_aligned(src, 64) && is_aligned(dst, 64)) {
        const size_t move_len = (len / 64) * 64;
        uint64_t tmp = static_cast<uint64_t>(condition);
        // `neg` turns the condition into 0 or all-ones, which is broadcast
        // into ymm1 as the store mask. %0 is then reused as a byte offset
        // that counts down from move_len to 0 in steps of 64, so the blocks
        // are processed from the end of the range towards the start.
        asm volatile(
            R"(
            neg %0
            vmovq %0, %%xmm2
            vbroadcastsd %%xmm2, %%ymm1
            mov %3, %0
            loop_%=:
            vmovdqa -64(%1, %0), %%ymm2
            vpmaskmovq %%ymm2, %%ymm1, -64(%2, %0)
            vmovdqa -32(%1, %0), %%ymm3
            vpmaskmovq %%ymm3, %%ymm1, -32(%2, %0)
            sub $64, %0
            jnz loop_%=
            )"
            : "+&r"(tmp)
            : "r"(src), "r"(dst), "rmi"(move_len)
            : "cc", "memory", "xmm2", "ymm1", "ymm2", "ymm3");
        src += move_len;
        dst += move_len;
        len -= move_len;
    }
#endif
    assert(is_aligned(src, 8) && "src need to be 8 bytes aligned");
    assert(is_aligned(dst, 8) && "dst need to be 8 bytes aligned");
    // 64-bit words.
    if (len >= 8) {
        size_t moves = len / 8;
        // Computed before the asm block, which counts `moves` down to zero.
        const size_t move_len = moves * 8;
        uint64_t tmp = static_cast<uint64_t>(condition);
        // `neg` sets the carry flag exactly when the condition is non-zero.
        // Neither `mov` nor `dec` modifies CF, so `cmovc` sees the same
        // carry in every iteration while `dec`/`jnz` drive the loop via ZF.
        asm volatile(
            R"(
            neg %0
            loop_%=:
            mov -8(%3, %1, 8), %0
            cmovc -8(%2, %1, 8), %0
            mov %0, -8(%3, %1, 8)
            dec %1
            jnz loop_%=
            )"
            : "+&r"(tmp), "+&r"(moves)
            : "r"(src), "r"(dst)
            : "cc", "memory");
        src += move_len;
        dst += move_len;
        len -= move_len;
    }
    // One 32-bit unit.
    if (len >= 4) {
        // Pure scratch register: written by the first `mov` before it is
        // ever read, hence a write-only early-clobber output ("=&r") rather
        // than a read-write operand fed from an uninitialized variable.
        uint32_t tmp;
        asm volatile(
            R"(
            test %1, %1
            mov (%3), %0
            cmovnz (%2), %0
            mov %0, (%3)
            )"
            : "=&r"(tmp)
            : "r"(condition), "r"(src), "r"(dst)
            : "cc", "memory");
        src += 4;
        dst += 4;
        len -= 4;
    }
    // One 16-bit unit.
    if (len >= 2) {
        // Scratch register, write-only for the same reason as above.
        uint16_t tmp;
        asm volatile(
            R"(
            test %1, %1
            mov (%3), %0
            cmovnz (%2), %0
            mov %0, (%3)
            )"
            : "=&r"(tmp)
            : "r"(condition), "r"(src), "r"(dst)
            : "cc", "memory");
        src += 2;
        dst += 2;
        len -= 2;
    }
    // Final single byte.
    if (len >= 1) {
        cmov_u8(condition, src, dst);
    }
}
/////////////////// TEST /////////////////
#include <cstring>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>
using namespace std;
// Set to true by check() as soon as any comparison fails; main() converts it
// into the process exit status.
static bool FAILED = false;

// 4096-byte buffer with 64-byte alignment, used to exercise cmov_bytes on
// 64-byte-aligned input.
struct alignas(64) Foo {
    uint8_t data[4096];

    // Fills the whole buffer with `value`.
    Foo(uint8_t value) { std::memset(data, value, sizeof(data)); }

    // Two buffers are equal when all of their bytes match.
    bool operator==(const Foo& other) const
    {
        return std::memcmp(data, other.data, sizeof(data)) == 0;
    }
};

// Writes `bytes` to `out` as "[a, b, c]", each element printed as an integer
// (not as a character). Shared by both operator<< overloads below.
template <class Range>
static std::ostream& print_byte_list(std::ostream& out, const Range& bytes)
{
    const char* separator = "";
    out << "[";
    for (uint8_t x : bytes) {
        // The separator goes to `out`, the stream actually being written,
        // so the output stays intact for streams other than std::cout.
        out << separator << static_cast<int>(x);
        separator = ", ";
    }
    out << "]";
    return out;
}

// Prints a byte vector as a bracketed, comma-separated list of integers.
std::ostream& operator<<(std::ostream& out, const std::vector<uint8_t>& v)
{
    return print_byte_list(out, v);
}

// Prints the contents of a Foo as a bracketed, comma-separated list of
// integers.
std::ostream& operator<<(std::ostream& out, const Foo& v)
{
    return print_byte_list(out, v.data);
}

// Compares `lhs` with `rhs` and reports the outcome on std::cout, prefixed
// with `msg`. On mismatch both values are printed in hexadecimal and the
// global FAILED flag is raised so the test run exits with a non-zero status.
// The stream's format flags are restored before returning.
template <class T>
void check(const std::string& msg, const T& lhs, const T& rhs)
{
    const std::ios::fmtflags saved_flags(std::cout.flags());
    std::cout << msg << "\t";
    if (lhs == rhs) {
        std::cout << "[PASS]" << std::endl;
    } else {
        FAILED = true;
        std::cout << "[FAIL]" << std::endl;
        std::cout << std::hex;
        std::cout << "left:" << lhs << std::endl;
        std::cout << "right:" << rhs << std::endl;
    }
    std::cout.flags(saved_flags);
}
int main()
{
{
uint8_t v1 = 0xff;
uint8_t v2 = 0xee;
cmov_u8(false, &v1, &v2);
check("cmov_u8 false", v2, uint8_t(0xee));
cmov_u8(true, &v1, &v2);
check("cmov_u8 true", v2, uint8_t(0xff));
}
{
uint16_t v1 = 0xffff;
uint16_t v2 = 0xeeee;
cmov_u16(false, &v1, &v2);
check("cmov_u16 false", v2, uint16_t(0xeeee));
cmov_u16(true, &v1, &v2);
check("cmov_u16 true", v2, uint16_t(0xffff));
}
{
uint32_t v1 = 0xffffffff;
uint32_t v2 = 0xeeeeeeee;
cmov_u32(false, &v1, &v2);
check("cmov_u32 false", v2, uint32_t(0xeeeeeeee));
cmov_u32(true, &v1, &v2);
check("cmov_u32 true", v2, uint32_t(0xffffffff));
}
{
uint64_t v1 = 0xffffffffffffffff;
uint64_t v2 = 0xeeeeeeeeeeeeeeee;
cmov_u64(false, &v1, &v2);
check("cmov_u64 false", v2, uint64_t(0xeeeeeeeeeeeeeeee));
cmov_u64(true, &v1, &v2);
check("cmov_u64 true", v2, uint64_t(0xffffffffffffffff));
}
{
for (size_t i : {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 16, 17, 64, 128, 256, 512, 1024, 2048, 3000}) {
string label("cmov_bytes " + to_string(i));
vector<uint8_t> v1(i, 0xff);
vector<uint8_t> v2(i, 0xee);
cmov_bytes(false, v1.data(), v2.data(), i);
check(label + " false", v2, vector<uint8_t>(i, 0xee));
cmov_bytes(true, v1.data(), v2.data(), i);
check(label + " true", v2, vector<uint8_t>(i, 0xff));
}
}
{
Foo v1(0xff);
Foo v2(0xee);
cmov_bytes(false, (uint8_t*)&v1, (uint8_t*)&v2, sizeof(Foo));
check("cmov_bytes 64B aligned false", v2, Foo(0xee));
cmov_bytes(true, (uint8_t*)&v1, (uint8_t*)&v2, sizeof(Foo));
check("cmov_bytes 64B aligned true", v2, Foo(0xff));
}
return FAILED ? 1 : 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment