Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save polkovnikov/4e80899fb20d3a977d271bff9c0ec84d to your computer and use it in GitHub Desktop.
Save polkovnikov/4e80899fb20d3a977d271bff9c0ec84d to your computer and use it in GitHub Desktop.
StackOverflow 72031323
#include <cstdint>
#include <string>
using String = std::string;
// Error-propagation helper macros. All of them `return Err{...}` from the enclosing
// function, so they may only be used inside functions whose return type can be built
// from an Err.

// Returns an Err when `cond` is false. The message contains the stringified condition,
// the source line and `msg` (converted with Str(msg)).
#define ASSERT_MSG(cond, msg) \
{ \
if (!(cond)) \
return Err{"Assert (" #cond ") failed at " + IntToStr(__LINE__) + "! Msg '" + Str(msg) + "'."}; \
}
// ASSERT_MSG with an empty message.
#define ASSERT(cond) ASSERT_MSG(cond, "")
// Declares a MemCleaner covering all bytes of `obj`.
// NOTE(review): `##` pastes the literal token `__LINE__` (operands of ## are not
// macro-expanded), so the variable is named `_<obj>___LINE__`, not `_<obj>_<line>`.
#define MEM_CLEANER(obj) MemCleaner _##obj##_##__LINE__((u8 *)&obj, sizeof(obj))
// Declares `type name[size]` together with a MemCleaner for it.
#define SECURE_ARRAY(type, name, size) \
type name[size]; \
MEM_CLEANER(name);
// "Error if": returns an Err whose text is the stringified condition when `cond` is true.
#define ERRDIF(cond) \
{ \
if (cond) \
return Err{#cond}; \
}
// Evaluates `code`; on failure returns an Err prefixed with the line and the source text,
// otherwise declares `name` in the enclosing scope as a copy of the Ok() value.
// Not wrapped in braces on purpose: `name` must stay visible after the macro.
// NOTE(review): the temporary is literally named `res___LINE__` (see MEM_CLEANER), so two
// TRYR uses in the same scope would redeclare it.
#define TRYR(name, code) \
auto const res_##__LINE__ = (code); \
if (!res_##__LINE__) \
return Err{IntToStr(__LINE__) + ": " + #name "=" #code ";\n" + res_##__LINE__.Err()}; \
auto name = (res_##__LINE__).Ok();
// Evaluates `code` and returns an Err (line + source text + inner error) when it failed.
#define TRYI(code) \
{ \
auto const res = (code); \
if (!res) \
return Err{IntToStr(__LINE__) + ": " + #code ";\n" + res.Err()}; \
}
// Inverse of TRYI: returns an Err when `code` unexpectedly succeeded.
#define TRYERR(code) \
{ \
auto const res = (code); \
if (res) \
return Err{IntToStr(__LINE__) + ": " + #code ";\n" + "Error should happen, but didn't!"}; \
}
// Re-returns the error stored in an already evaluated result `res`.
// `res` is used unparenthesised, so pass a plain identifier.
#define CHECK(res) \
{ \
if (!res) \
return Err{res.Err()}; \
}
// "Hex to array": declares a wiped u8 array `arr` of `bysize` bytes and fills it by
// decoding the hex string `hex` with HexDecodeFit, returning an Err on failure.
#define H2A(arr, bysize, hex) \
SECURE_ARRAY(u8, arr, bysize); \
TRYI(HexDecodeFit(hex, arr, bysize));
// Short aliases for the fixed-width integer and size types used throughout the file.
using u8 = uint8_t;
using u16 = uint16_t;
using i32 = int32_t;
using u32 = uint32_t;
using i64 = int64_t;
using u64 = uint64_t;
using SizeT = std::size_t; // unsigned size / index type
using ISizeT = std::ptrdiff_t; // signed counterpart of SizeT
// Counts the bytes that precede the first zero byte starting at `ptr`.
// The scan has no upper bound: the buffer must contain a terminating zero.
SizeT StrLen(void const * ptr) {
    u8 const * const bytes = (u8 const *)ptr;
    SizeT len = 0;
    while (bytes[len] != 0)
        ++len;
    return len;
}
// Hand-written replacement for std::memset: stores `val` into each of the `size`
// bytes starting at `ptr`.
inline void MemSet(void * ptr, u8 val, SizeT size) {
    u8 * cursor = (u8 *)ptr;
    for (SizeT left = size; left != 0; --left)
        *cursor++ = val;
}
// Hand-written replacement for std::memcpy: copies `size` bytes from `src` to `dst`,
// one byte at a time in ascending address order.
inline void MemCpy(void * dst, void const * src, SizeT size) {
    u8 * out = (u8 *)dst;
    u8 const * in = (u8 const *)src;
    for (SizeT left = size; left != 0; --left)
        *out++ = *in++;
}
// Hand-written replacement for std::memcmp: compares the first `size` bytes of `a` and
// `b` as unsigned values. Returns -1 / 1 at the first differing byte (a < b / a > b)
// and 0 when all bytes are equal. Stops at the first difference.
inline int MemCmp(void const * a, void const * b, u32 size) {
    u8 const * const lhs = (u8 const *)a;
    u8 const * const rhs = (u8 const *)b;
    for (u32 i = 0; i != size; ++i) {
        if (lhs[i] != rhs[i])
            return lhs[i] < rhs[i] ? -1 : 1;
    }
    return 0;
}
// Minimal stand-in for std::move: casts the lvalue `obj` to an rvalue reference so that
// move constructors / move assignments are selected. Nothing is moved by the call itself.
// Declared constexpr and noexcept so that the constexpr accessors which call it can
// actually be evaluated at compile time.
template <typename T>
constexpr T && Move(T & obj) noexcept {
    return static_cast<T &&>(obj);
}
// Scope guard that zeroes a memory region both when the guard is created and when it
// is destroyed (used through MEM_CLEANER / SECURE_ARRAY for key and hash buffers).
// The guard does not own the memory: the region must outlive the guard.
class MemCleaner {
public:
    // ptr    - first byte of the region to wipe.
    // bysize - size of the region in bytes.
    MemCleaner(u8 * ptr, SizeT bysize) : bysize_(bysize), ptr_(ptr) { Wipe(); }
    ~MemCleaner() { Wipe(); }

private:
    // Writes go through a volatile pointer so the compiler cannot discard the final wipe
    // as a dead store to memory that is about to go out of scope.
    void Wipe() const {
        u8 volatile * const bytes = ptr_;
        for (SizeT i = 0; i < bysize_; ++i)
            bytes[i] = 0;
    }

    SizeT bysize_ = 0;
    u8 * ptr_ = nullptr;
};
// Global-namespace look-alike of std::initializer_list: a read-only view over a
// contiguous range described by a [first, last) pointer pair.
template <class _Elem>
class initializer_list {
public:
    using value_type = _Elem;
    using reference = const _Elem &;
    using const_reference = const _Elem &;
    using size_type = size_t;
    using iterator = const _Elem *;
    using const_iterator = const _Elem *;

    // Empty view: both pointers are null.
    constexpr initializer_list() noexcept {}

    // View over the elements in [first, last).
    constexpr initializer_list(const _Elem * first, const _Elem * last) noexcept
        : first_(first), last_(last) {}

    constexpr const _Elem * begin() const noexcept { return first_; }
    constexpr const _Elem * end() const noexcept { return last_; }

    // Number of elements in the view.
    constexpr size_t size() const noexcept { return static_cast<size_t>(end() - begin()); }

private:
    const _Elem * first_ = nullptr;
    const _Elem * last_ = nullptr;
};
// Primary template; only the two-element specialisation below is defined.
template <typename... Args>
struct Tuple;

// Two-element tuple with public members `first` / `second` and get<I>() accessors.
template <typename A, typename B>
struct Tuple<A, B> {
    A first{};
    B second{};

    Tuple() = default;
    Tuple(Tuple const &) = default;
    Tuple(Tuple &&) = default;
    Tuple & operator=(Tuple const &) = default;
    Tuple & operator=(Tuple &&) = default;

    // One constructor per copy/move combination of the two elements.
    Tuple(A const & a, B const & b) : first(a), second(b) {}
    Tuple(A && a, B const & b) : first(static_cast<A &&>(a)), second(b) {}
    Tuple(A const & a, B && b) : first(a), second(static_cast<B &&>(b)) {}
    Tuple(A && a, B && b) : first(static_cast<A &&>(a)), second(static_cast<B &&>(b)) {}

    // get<0>() is `first`, get<1>() is `second`; any other index fails to compile.
    // Every overload carries a ref-qualifier: C++17 does not allow overloading
    // ref-qualified and unqualified member functions with the same parameter list.
    // Plain casts are used instead of a helper so the accessors stay genuinely constexpr.
    template <SizeT I>
    constexpr auto & get() & {
        static_assert(I < 2, "Tuple index out of range");
        if constexpr (I == 0)
            return first;
        else
            return second;
    }
    template <SizeT I>
    constexpr auto && get() && {
        static_assert(I < 2, "Tuple index out of range");
        if constexpr (I == 0)
            return static_cast<A &&>(first);
        else
            return static_cast<B &&>(second);
    }
    template <SizeT I>
    constexpr auto const & get() const & {
        static_assert(I < 2, "Tuple index out of range");
        if constexpr (I == 0)
            return first;
        else
            return second;
    }
};
// Forward declarations of the standard tuple-protocol traits so that they can be
// specialised below for Tuple and Array.
// NOTE(review): these must match the declarations of the standard library in use;
// presumably including <utility> would be the portable way to obtain them - confirm.
namespace std {
template <class _Tuple>
struct tuple_size;
template <size_t _Index, class _Tuple>
struct tuple_element;
} // namespace std
// Tuple-protocol traits for Tuple<A, B>: two elements, of types A and B.
// Together with Tuple::get<I>() these allow structured bindings on a Tuple.
namespace std {
template <typename A, typename B>
struct tuple_size<Tuple<A, B>> {
static SizeT constexpr value = 2;
};
template <typename A, typename B>
struct tuple_element<0, Tuple<A, B>> {
using type = A;
};
template <typename A, typename B>
struct tuple_element<1, Tuple<A, B>> {
using type = B;
};
} // namespace std
// Maps a deduced forwarding-reference type to the value type stored in the tuple:
// strips the reference and a top-level const.
template <typename T>
struct MakeTupleElem {
    using type = T;
};
template <typename T>
struct MakeTupleElem<T const> {
    using type = T;
};
template <typename T>
struct MakeTupleElem<T &> {
    using type = typename MakeTupleElem<T>::type;
};
template <typename T>
struct MakeTupleElem<T &&> {
    using type = typename MakeTupleElem<T>::type;
};

// Builds a Tuple holding copies of (or, for rvalue arguments, values moved from) `a`
// and `b`. The element types are the argument types without reference and const, so
// lvalue arguments produce a tuple of values rather than a tuple of references.
template <typename A, typename B>
auto MakeTuple(A && a, B && b) {
    using ElemA = typename MakeTupleElem<A>::type;
    using ElemB = typename MakeTupleElem<B>::type;
    // static_cast<X &&> forwards each argument with its original value category.
    return Tuple<ElemA, ElemB>(static_cast<A &&>(a), static_cast<B &&>(b));
}
// Fixed-size array with a std::array-like interface (no heap allocation).
// The storage is always value-initialised, so a default-constructed Array holds T{}
// in every slot and the constructors are valid constexpr constructors.
template <typename T, SizeT Size>
class Array {
public:
    constexpr Array() = default;

    // Copies the listed values into the leading slots; the remaining slots keep T{}.
    // Precondition: il.size() <= Size.
    constexpr Array(std::initializer_list<T> const & il) {
        SizeT i = 0;
        for (auto const & e : il)
            arr_[i++] = e;
    }

    constexpr Array(Array const &) = default;
    constexpr Array(Array &&) = default;
    constexpr Array & operator=(Array const &) = default;
    constexpr Array & operator=(Array &&) = default;

    // Unchecked element access.
    constexpr T & operator[](SizeT i) { return arr_[i]; }
    constexpr T const & operator[](SizeT i) const { return arr_[i]; }

    constexpr SizeT size() const { return Size; }
    constexpr T * data() { return &arr_[0]; }
    constexpr T const * data() const { return &arr_[0]; }
    constexpr T * begin() { return data(); }
    constexpr T const * begin() const { return data(); }
    constexpr T * end() { return data() + size(); }
    constexpr T const * end() const { return data() + size(); }

    // Compile-time indexed access (used by structured bindings). Every overload carries
    // a ref-qualifier: C++17 does not allow overloading ref-qualified and unqualified
    // member functions with the same parameter list.
    template <SizeT I>
    constexpr T & get() & {
        static_assert(I < Size, "Array index out of range");
        return arr_[I];
    }
    template <SizeT I>
    constexpr T && get() && {
        static_assert(I < Size, "Array index out of range");
        return static_cast<T &&>(arr_[I]);
    }
    template <SizeT I>
    constexpr T const & get() const & {
        static_assert(I < Size, "Array index out of range");
        return arr_[I];
    }

private:
    T arr_[Size]{};
};
// Tuple-protocol traits for Array<T, Size>: Size elements, each of type T.
// Together with Array::get<I>() these allow structured bindings on an Array.
namespace std {
template <typename T, SizeT Size>
struct tuple_size<Array<T, Size>> {
static SizeT constexpr value = Size;
};
template <typename T, SizeT Size, SizeT Idx>
struct tuple_element<Idx, Array<T, Size>> {
using type = T;
};
} // namespace std
// Raw heap storage for `T` slots, backed by a String used as a byte buffer.
// The class never constructs or destroys T objects; copying a HeapMem copies bytes.
template <typename T>
class HeapMem {
public:
    HeapMem() = default;
    HeapMem(HeapMem const &) = default;
    HeapMem(HeapMem &&) = default;
    HeapMem & operator=(HeapMem const &) = default;
    HeapMem & operator=(HeapMem && other) = default;

    // Resizes the buffer so that it holds exactly `size` slots of T.
    void ReAllocate(SizeT size) { mem_.resize(size * sizeof(T)); }

    // Number of whole T slots the buffer currently holds.
    SizeT Size() const { return mem_.size() / sizeof(T); }

    // Pointer to the first slot.
    T * Ptr() { return reinterpret_cast<T *>(&mem_[0]); }
    T const * Ptr() const { return reinterpret_cast<T const *>(&mem_[0]); }

    // Unchecked slot access.
    T & operator[](SizeT i) { return *(Ptr() + i); }
    T const & operator[](SizeT i) const { return *(Ptr() + i); }

private:
    String mem_;
};
template <typename T>
class Vector {
public:
Vector() {}
Vector(std::initializer_list<T> const & il) {
// insert(begin(), il.begin(), il.end());
resize(il.size());
SizeT i = 0;
for (auto const & e : il)
(*this)[i++] = e;
}
Vector(SizeT size) { resize(size); }
~Vector() { resize(0); }
Vector(Vector const &) = default;
Vector(Vector && other) { *this = Move(other); }
Vector & operator=(Vector const &) = default;
Vector & operator=(Vector && other) {
data_ = Move(other.data_);
size_ = other.size_;
other.size_ = 0;
return *this;
}
T * begin() { return data(); }
T const * begin() const { return data(); }
T * end() { return data() + size(); }
T const * end() const { return data() + size(); }
SizeT size() const { return size_; }
T & operator[](SizeT i) { return data_[i]; }
T const & operator[](SizeT i) const { return data_[i]; }
T * data() { return &data_[0]; }
T const * data() const { return &data_[0]; }
template <typename ItT>
void insert(T * ptr, ItT const & begin, ItT const & end) {
SizeT const idx = ptr == this->begin() ? 0 : ptr - this->begin(), cnt = end - begin;
Insert(idx, cnt);
SizeT i = idx;
for (auto it = begin; it != end; ++it, ++i)
(*this)[i] = *it;
}
void push_back(T const & obj) { resize(size() + 1, obj); }
void resize(SizeT size, T const & val = T{}) {
if (size <= size_) {
for (SizeT i = size; i < size_; ++i)
data_[i].~T();
size_ = size;
return;
}
auto cur_size = data_.Size();
while (cur_size < size)
cur_size = cur_size == 0 ? 1 : cur_size * 2;
if (data_.Size() != cur_size) {
HeapMem<T> newm;
newm.ReAllocate(cur_size);
for (SizeT i = 0; i < size_; ++i)
new (&newm[i]) T(Move(data_[i]));
data_ = Move(newm);
}
for (SizeT i = size_; i < size; ++i)
new (&data_[i]) T();
size_ = size;
}
void Insert(SizeT idx, SizeT cnt, T const & value = T{}) {
if (cnt == 0)
return;
if (idx > size())
return;
auto const prev_size = size();
resize(size() + cnt);
for (ISizeT i = ISizeT(size()) - 1; i >= ISizeT(idx); --i)
(*this)[i + cnt] = Move((*this)[i]);
for (SizeT i = idx; i < idx + cnt; ++i)
(*this)[i] = value;
}
bool operator==(Vector const & other) const {
if (size() != other.size())
return false;
for (SizeT i = 0; i < size(); ++i)
if ((*this)[i] != other[i])
return false;
return true;
}
private:
SizeT size_ = 0;
HeapMem<T> data_;
};
class Str {
public:
Str() : data_(1) {}
Str(char const * ptr) : data_(StrLen(ptr) + 1) { MemCpy(&data_[0], ptr, size()); }
Str(char const * ptr, SizeT size) : data_(size + 1) { MemCpy(&data_[0], ptr, this->size()); }
Str(SizeT size, char val) {
data_.resize(size + 1, val);
data_[size] = 0;
}
Str(Str const &) = default;
Str(Str &&) = default;
Str & operator=(Str const &) = default;
Str & operator=(Str &&) = default;
char * begin() { return data(); }
char const * begin() const { return data(); }
char * end() { return data() + size(); }
char const * end() const { return data() + size(); }
void append(SizeT cnt, char c) { insert(size(), cnt, c); }
void insert(SizeT idx, SizeT cnt, char c) { data_.Insert(idx, cnt, c); }
SizeT size() const { return data_.size() - 1; }
void resize(SizeT size) {
data_.resize(size + 1);
data_[size] = 0;
}
char & operator[](SizeT i) { return data_[i]; }
char const & operator[](SizeT i) const { return data_[i]; }
char * data() { return &data_[0]; }
char const * data() const { return &data_[0]; }
char const * c_str() const { return data(); }
Str & operator+=(Str const & other) {
auto const prev_size = size();
data_.resize(size() + other.size() + 1);
MemCpy(&data_[prev_size], other.data(), other.size());
data_[size()] = 0;
return *this;
}
Str operator+(Str const & other) const {
Str c = *this;
c += other;
return c;
}
friend Str operator+(char const * a, Str const & b) { return Str(a) + b; }
void push_back(char c) {
data_.resize(data_.size() + 1);
data_[size() - 1] = c;
}
bool operator==(Str const & other) const { return this->data_ == other.data_; }
Str Upper() const {
Str c = *this;
for (SizeT i = 0; i < c.size(); ++i)
c[i] = 'a' <= c[i] && c[i] <= 'z' ? c[i] - 'a' + 'A' : c[i];
return c;
}
private:
Vector<char> data_;
};
// Error value carrying a human-readable message. The error macros return `Err{...}`,
// which converts to a failed Result through Result's Err constructor.
struct Err {
Str err;
};
template <typename T>
class [[nodiscard]] Result {
public:
using OkT = T;
Result(T const & res) : is_ok_(true) { new (&res_) T(res); }
Result(Err const & err) : is_ok_(false) { new (&err_) Str(err.err); }
~Result() {
if (is_ok_)
res_.~T();
else
err_.~Str();
}
operator bool() const { return is_ok_; }
T & Ok() { return res_; }
T const & Ok() const { return res_; }
Str const & Err() const { return err_; }
private:
bool is_ok_ = false;
struct Dummy {};
union {
Dummy dummy_;
T res_;
Str err_;
};
};
using Error = Result<u32>;
// Formats `x` in decimal, with a leading '-' for negative values.
Str IntToStr(i64 x) {
    if (x == 0)
        return "0";
    bool const negative = x < 0;
    // Take the magnitude in unsigned arithmetic: negating the most negative i64 as a
    // signed value would overflow.
    u64 magnitude = negative ? u64(0) - u64(x) : u64(x);
    char digits[20]; // 2^63 has 19 decimal digits
    SizeT count = 0;
    while (magnitude) {
        digits[count++] = char('0' + magnitude % 10);
        magnitude /= 10;
    }
    Str r;
    if (negative)
        r.append(1, '-');
    // The digits were produced least significant first.
    while (count > 0)
        r.append(1, digits[--count]);
    return r;
}
// https://en.wikipedia.org/wiki/SHA-2
class SHA256 {
public:
enum {
output_size = 8,
blocksize = 1,
block_size = 64,
digest_size = 32,
};
static u32 constexpr F32 = 0xFFFFFFFF;
static Array<u32, 64> constexpr c_k = {
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
};
static Array<u32, 8> constexpr c_h = {
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
};
static inline u32 endian_reverse32(u32 x) {
u32 step16 = x << 16 | x >> 16;
return ((step16 << 8) & 0xff00ff00) | ((step16 >> 8) & 0x00ff00ff);
}
static inline u64 endian_reverse64(uint64_t x) {
u64 step32, step16;
step32 = x << 32 | x >> 32;
step16 = (step32 & 0x0000FFFF0000FFFFULL) << 16 | (step32 & 0xFFFF0000FFFF0000ULL) >> 16;
return (step16 & 0x00FF00FF00FF00FFULL) << 8 | (step16 & 0xFF00FF00FF00FF00ULL) >> 8;
}
SHA256() {
_k = c_k;
_h = c_h;
}
static inline u32 _rotr(u32 x, u32 y) { return ((x >> y) | (x << (32 - y))) /*& F32*/; }
static inline u32 _maj(u32 x, u32 y, u32 z) { return (x & y) ^ (x & z) ^ (y & z); }
static inline u32 _ch(u32 x, u32 y, u32 z) { return (x & y) ^ ((~x) & z); }
void _compress(u8 const * block) {
MemSet(w.data(), 0, w.size() * sizeof(w[0]));
MemCpy(w.data(), block, 64);
for (size_t i = 0; i < 16; ++i)
w[i] = endian_reverse32(w[i]);
for (size_t i = 16; i < 64; ++i) {
u32 const s0 = _rotr(w[i - 15], 7) ^ _rotr(w[i - 15], 18) ^ (w[i - 15] >> 3),
s1 = _rotr(w[i - 2], 17) ^ _rotr(w[i - 2], 19) ^ (w[i - 2] >> 10);
w[i] = (w[i - 16] + s0 + w[i - 7] + s1) /*& F32*/;
}
auto h2 = _h;
auto & [a, b, c, d, e, f, g, h] = h2;
for (size_t i = 0; i < 64; ++i) {
u32 const s0 = _rotr(a, 2) ^ _rotr(a, 13) ^ _rotr(a, 22), t2 = s0 + _maj(a, b, c),
s1 = _rotr(e, 6) ^ _rotr(e, 11) ^ _rotr(e, 25), t1 = h + s1 + _ch(e, f, g) + _k[i] + w[i];
h = g;
g = f;
f = e;
e = (d + t1) /*& F32*/;
d = c;
c = b;
b = a;
a = (t1 + t2) /*& F32*/;
}
for (size_t i = 0; i < _h.size(); ++i)
_h[i] = (_h[i] + h2[i]) /*& F32*/;
}
Error update(u8 const * m, size_t size) {
ASSERT(!fin_);
if (size == 0)
return 0;
_counter += size;
size_t const port = std::min(size_t(64 - cache_size), size);
MemCpy(_cache.data() + cache_size, m, port);
cache_size += port;
m += port;
size -= port;
if (cache_size < 64)
return 0;
_compress(_cache.data());
cache_size = 0;
while (size >= 64) {
_compress(m);
m += 64;
size -= 64;
}
MemCpy(_cache.data(), m, size);
cache_size = size;
return 0;
}
static auto _pad(u64 msglen) {
size_t mdi = msglen & 0x3F;
u64 length = endian_reverse64(msglen << 3);
size_t padlen = mdi < 56 ? 55 - mdi : 119 - mdi;
Array<u8, 128> r = {0x80};
*(u64 *)(r.data() + 1 + padlen) = length; // MemCpy(r.data() + 1 + padlen, &length, 8);
return MakeTuple(Move(r), 1 + padlen + 8);
}
Result<Array<u8, 32>> digest() {
if (!fin_) {
auto [pad, padl] = _pad(_counter);
TRYI(update(pad.data(), padl));
for (size_t i = 0; i < output_size; ++i)
*(u32 *)&dig_[i * 4] = endian_reverse32(_h[i]);
fin_ = true;
}
return dig_;
}
Result<Str> hexdigest() {
static char const tab[] = "0123456789abcdef";
TRYR(dig, digest());
Str r;
for (auto e : dig) {
r.append(1, tab[e >> 4]);
r.append(1, tab[e & 0xF]);
}
return r;
}
private:
bool fin_ = false;
u64 _counter = 0;
Array<u8, 64> _cache;
size_t cache_size = 0;
Array<u32, 64> _k;
Array<u32, 8> _h;
Array<u32, 64> w;
Array<u8, 32> dig_;
};
// Self-test for SHA256: hashes a table of messages and compares each hex digest with
// its reference value. Returns an Err describing the first mismatch.
Error TestSha256() {
// Pairs of {message, expected lower-case hex digest}.
Vector<Tuple<Str, Str>> tests = {
{"", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"},
{"a", "ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"},
{"abc", "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"},
{"message digest", "f7846f55cf23e14eebeab5b4e1550cad5b509e3348fbc4efa3a1413d393cb650"},
{"abcdefghijklmnopqrstuvwxyz", "71c480df93d6ae2f1efad1447c66c9525e316218cf51fc8d9ed832f2daf18b73"},
{"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789",
"db4bfcbd4da0cd85a60c3c37d3fbd8805c77f15fc6b1fdfe614ee0a7c8fdb4c0"},
{"12345678901234567890123456789012345678901234567890123456789"
"012345678901234567890",
"f371bc4a311f2b009eef952dd83ca80e2b60026c8e935592d0f9c308453c813e"},
};
for (auto const & [msg, sig] : tests) {
auto m = SHA256();
TRYI(m.update((uint8_t *)msg.data(), msg.size()));
TRYR(hexdigest, m.hexdigest());
// After the digest has been produced update() is expected to fail.
TRYERR(m.update((uint8_t *)msg.data(), msg.size())); // Generate error
ASSERT_MSG(hexdigest == sig, "msg '" + msg + "' m.hexdigest() " + hexdigest + " sig " + sig);
}
return 0;
}
// Abstract hash function as used by HMAC.
struct HashInterface {
// Block size of the hash in bytes (HMAC pads its keys to this size).
virtual u32 GetBlockSize() const = 0;
// Size of the produced hash in bytes.
virtual u32 GetHashSize() const = 0;
// Begins a new hash computation.
virtual Error Start() = 0;
// Feeds `size` bytes at `data` into the running computation.
virtual Error Process(u8 const * data, u32 size) = 0;
// Finishes the computation and writes the hash to `hash`.
// NOTE(review): presumably `hash` must have room for GetHashSize() bytes - confirm.
virtual Error Stop(u8 * hash) = 0;
virtual ~HashInterface() {}
};
class HashInterface_Sha256 : public HashInterface {
public:
virtual u32 GetBlockSize() const { return 64; }
virtual u32 GetHashSize() const { return 32; }
virtual Error Start() {
hash_ = SHA256{};
return 0;
}
virtual Error Process(u8 const * data, u32 size) {
TRYI(hash_.update(data, size));
return 0;
}
virtual Error Stop(u8 * hash) {
TRYR(dig, hash_.digest());
MemCpy(hash, dig.data(), sizeof(dig));
return 0;
}
private:
SHA256 hash_;
};
// https://en.wikipedia.org/wiki/Hash-based_message_authentication_code
// HMAC over any HashInterface. Usage: Start(key), Process(data) any number of times,
// then Stop(mac).
class HMAC {
public:
// Largest hash block size supported; also the size of the two key pad buffers.
enum { c_max_block_size = 128 };
// Stores a reference to `hasher`, which must outlive this object.
HMAC(HashInterface & hasher);
Error Start(u8 const * key, u32 key_size);
Error Process(u8 const * data, u32 size);
Error Stop(u8 * mac);
u32 GetMACSize() const;
private:
HashInterface & hasher_;
// Outer and inner key pads, filled by Start().
// NOTE(review): these hold key-derived material and are not wiped on destruction.
u8 o_key_[c_max_block_size];
u8 i_key_[c_max_block_size];
};
// Returns the smaller of the two values.
inline u32 Min(u32 a, u32 b) {
    if (a < b)
        return a;
    return b;
}
// XORs the first `size` bytes of `src` into `dst`, byte by byte.
inline void Xor(u8 * dst, u8 const * src, u32 size) {
    for (u32 i = 0; i < size; ++i)
        dst[i] ^= src[i];
}
// Binds the HMAC to `hasher` and zeroes both key pad buffers.
HMAC::HMAC(HashInterface & hasher) : hasher_(hasher) {
    MemSet(o_key_, 0, sizeof(o_key_));
    MemSet(i_key_, 0, sizeof(i_key_));
}
// Begins a MAC computation with the `key_size` bytes at `key`.
// Prepares the inner and outer key pads and feeds the inner pad into the hasher.
// Fails when the hasher's sizes do not fit the pad buffers or when the hasher fails.
Error HMAC::Start(u8 const * key, u32 key_size) {
ERRDIF(hasher_.GetHashSize() > hasher_.GetBlockSize());
ERRDIF(hasher_.GetBlockSize() > c_max_block_size);
// Both pads start as the HMAC ipad / opad constants; the key is XORed in below.
MemSet(i_key_, 0x36, sizeof(i_key_));
MemSet(o_key_, 0x5C, sizeof(o_key_));
if (key_size > hasher_.GetBlockSize()) {
// A key longer than one block is replaced by its hash.
TRYI(hasher_.Start());
TRYI(hasher_.Process(key, key_size));
SECURE_ARRAY(u8, hash, c_max_block_size);
TRYI(hasher_.Stop(hash));
Xor(i_key_, hash, hasher_.GetHashSize());
Xor(o_key_, hash, hasher_.GetHashSize());
} else {
Xor(i_key_, key, key_size);
Xor(o_key_, key, key_size);
}
// Inner hash = H(inner pad || message); the message follows through Process().
TRYI(hasher_.Start());
TRYI(hasher_.Process(i_key_, hasher_.GetBlockSize()));
return 0;
}
// Feeds `size` message bytes at `data` into the inner hash. Call after Start().
Error HMAC::Process(u8 const * data, u32 size) {
TRYI(hasher_.Process(data, size));
return 0;
}
// Finishes the computation and writes the MAC to `mac`.
// NOTE(review): presumably `mac` must have room for GetMACSize() bytes - confirm.
Error HMAC::Stop(u8 * mac) {
// Finish the inner hash into a wiped temporary buffer ...
SECURE_ARRAY(u8, hash, c_max_block_size);
TRYI(hasher_.Stop(hash));
// ... then MAC = H(outer pad || inner hash).
TRYI(hasher_.Start());
TRYI(hasher_.Process(o_key_, hasher_.GetBlockSize()));
TRYI(hasher_.Process(hash, hasher_.GetHashSize()));
TRYI(hasher_.Stop(mac));
return 0;
}
// Size of the MAC in bytes: the hash size of the underlying hasher.
u32 HMAC::GetMACSize() const { return hasher_.GetHashSize(); }
// Formats the `size` bytes at `data` as upper-case hexadecimal, two characters per byte.
Str ToHex(void const * data, size_t size) {
    static char const tab[] = "0123456789ABCDEF";
    u8 const * const bytes = (u8 const *)data;
    Str r;
    for (size_t i = 0; i < size; ++i) {
        u8 const b = bytes[i];
        r.append(1, tab[b >> 4]); // high nibble first
        r.append(1, tab[b & 0xF]);
    }
    return r;
}
// Value (0-15) of one hexadecimal digit, upper or lower case.
// Any other character yields 0.
inline u8 HexChrToInt(char h) {
    if (h >= '0' && h <= '9')
        return h - '0';
    if (h >= 'A' && h <= 'F')
        return h - 'A' + 10;
    if (h >= 'a' && h <= 'f')
        return h - 'a' + 10;
    return 0;
}
// Decodes the zero-terminated hex string `hex_str` into exactly `bysize` bytes at `dst`.
// Non-hex characters are skipped when `skip_bad` is true and are an error otherwise.
// A single trailing hex digit without a partner becomes the high nibble of a final
// byte, whether or not skipped characters follow it.
// Fails when the string decodes to more or fewer than `bysize` bytes.
Error HexDecodeFit(char const * hex_str, u8 * dst, u32 bysize, bool skip_bad = true) {
    size_t const len = StrLen(hex_str);
    bool have_high = false; // a high nibble is waiting for its low nibble
    u8 cur = 0;
    u32 pos = 0;
    for (size_t i = 0; i < len; ++i) {
        char const c = hex_str[i];
        bool const is_hex = ('0' <= c && c <= '9') || ('A' <= c && c <= 'F') || ('a' <= c && c <= 'f');
        if (!is_hex) {
            ERRDIF(!skip_bad);
            continue;
        }
        if (!have_high) {
            cur = u8(HexChrToInt(c) << 4);
            have_high = true;
        } else {
            cur |= HexChrToInt(c);
            ERRDIF(pos >= bysize);
            dst[pos] = cur;
            ++pos;
            cur = 0;
            have_high = false;
        }
    }
    // Flush a dangling high nibble, so the result does not depend on whether skipped
    // characters follow the last digit.
    if (have_high) {
        ERRDIF(pos >= bysize);
        dst[pos] = cur;
        ++pos;
    }
    ERRDIF(pos < bysize);
    return 0;
}
// Computes HMAC-SHA256 of the `data_bysize` bytes at `data` under the `key_bysize`
// bytes at `key` and writes the 32-byte MAC to `hash`.
// The HMAC interface takes 32-bit lengths: a key that does not fit is rejected and the
// data is fed in chunks, so neither length is silently truncated.
Error HmacSha256(void const * key, SizeT key_bysize, void const * data, SizeT data_bysize, u8 hash[32]) {
    SizeT const c_max_chunk = 0xFFFFFFFFu;
    ERRDIF(key_bysize > c_max_chunk);
    HashInterface_Sha256 hi_sha256;
    HMAC hmac(hi_sha256);
    TRYI(hmac.Start((u8 const *)key, u32(key_bysize)));
    u8 const * cursor = (u8 const *)data;
    SizeT left = data_bysize;
    while (left > 0) {
        u32 const chunk = u32(left > c_max_chunk ? c_max_chunk : left);
        TRYI(hmac.Process(cursor, chunk));
        cursor += chunk;
        left -= chunk;
    }
    TRYI(hmac.Stop(hash));
    return 0;
}
// Self-test for HmacSha256 with a binary key and message given as hex strings.
// The spaces inside the hex strings are skipped by HexDecodeFit.
Error TestHmacSha256() {
H2A(key, 32, "603deb1015ca71be2b73aef0857d7781 1f352c073b6108d72d9810a30914dff4");
H2A(data, 67,
"6bc1bee22e409f96e93d7e117393172a ae2d8a571e03ac9c9eb76fac45af8e51 30c81c46a35ce411e5fbc1191a0a52ef "
"f69f2445df4f9b17ad2b417be66c3710 123456");
// Expected MAC.
H2A(hash_ref, 32, "E8867BBCD0D34DF56885446C2D6639C1FBB362BE3ECDCBB41D6CE0EB74C01DB8");
SECURE_ARRAY(u8, hash, 32);
TRYI(HmacSha256(key, sizeof(key), data, sizeof(data), hash));
ASSERT_MSG(MemCmp(hash, hash_ref, sizeof(hash)) == 0,
"hash " + ToHex(hash, sizeof(hash)) + " hash_ref " + ToHex(hash_ref, sizeof(hash_ref)));
return 0;
}
// Checks one HMAC-SHA256 test vector with a text key and message.
// `hash_hex` is the expected MAC in hex; the comparison is case-insensitive because
// both sides are compared in upper case.
Error TestHmacSha256v2Single(Str const & key, Str const & data, Str const & hash_hex) {
SECURE_ARRAY(u8, hash, 32);
TRYI(HmacSha256(key.data(), key.size(), data.data(), data.size(), hash));
ASSERT_MSG(ToHex(hash, sizeof(hash)) == hash_hex.Upper(),
"actual " + ToHex(hash, sizeof(hash)) + " ref " + hash_hex.Upper());
return 0;
}
// Self-test for HMAC-SHA256 with text inputs: one short key and one key longer than
// the 64-byte hash block (which exercises the key-hashing branch of HMAC::Start).
Error TestHmacSha256v2() {
TRYI(TestHmacSha256v2Single("key", "The quick brown fox jumps over the lazy dog",
"f7bc83f430538424b13298e6aa6fb143ef4d59a14946175997479dbc2d1a3cd8"));
TRYI(
TestHmacSha256v2Single("The quick brown fox jumps over the lazy dogThe quick brown fox jumps over the lazy dog",
"message", "5597b93a2843078cbb0c920ae41dfe20f1685e10c67e423c11ab91adfc319d12"));
return 0;
}
// https://stackoverflow.com/a/41094722/941531
// Encodes the `len` bytes at `src0` as Base64 using the alphabet A-Z a-z 0-9 + / with
// '=' padding. Returns an empty string when the output length would overflow size_t.
Str Base64Encode(void const * src0, size_t len) {
    static const unsigned char base64_table[65] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    unsigned char const * const src = (unsigned char *)src0;
    size_t const olen = 4 * ((len + 2) / 3); /* every 3 input bytes become 4 characters */
    if (olen < len)
        return Str(); /* integer overflow */
    Str outStr;
    outStr.resize(olen);
    unsigned char * pos = (unsigned char *)&outStr[0];
    size_t i = 0;
    // Full 3-byte groups.
    for (; len - i >= 3; i += 3) {
        unsigned char const b0 = src[i], b1 = src[i + 1], b2 = src[i + 2];
        *pos++ = base64_table[b0 >> 2];
        *pos++ = base64_table[((b0 & 0x03) << 4) | (b1 >> 4)];
        *pos++ = base64_table[((b1 & 0x0f) << 2) | (b2 >> 6)];
        *pos++ = base64_table[b2 & 0x3f];
    }
    // Final group of 1 or 2 bytes, padded with '=' to 4 characters.
    size_t const tail = len - i;
    if (tail != 0) {
        *pos++ = base64_table[src[i] >> 2];
        if (tail == 1) {
            *pos++ = base64_table[(src[i] & 0x03) << 4];
            *pos++ = '=';
        } else {
            *pos++ = base64_table[((src[i] & 0x03) << 4) | (src[i + 1] >> 4)];
            *pos++ = base64_table[(src[i + 1] & 0x0f) << 2];
        }
        *pos++ = '=';
    }
    return Str(outStr.data(), outStr.size());
}
// Decodes the `len0` bytes of Base64 text at `data0`.
// Spaces, tabs, CR and LF are ignored; any other character outside the Base64 alphabet
// (including '=') is an error. The decoded bytes are re-encoded at the end and must
// reproduce the whitespace-stripped input exactly, otherwise an error is returned.
Result<Str> Base64Decode(void const * data0, const size_t len0) {
Str s((char *)data0, len0), r;
// r = input without whitespace, validated character by character.
for (auto c : s)
if (!(c == ' ' || c == '\n' || c == '\r' || c == '\t')) {
ASSERT_MSG(('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z') || ('0' <= c && c <= '9') ||
(c == '+' || c == '/' || c == '='),
"Wrong base64 char '" + Str(1, c) + "' (0x" + ToHex(&c, 1) + ")!");
r.append(1, c);
}
void const * data = r.data();
size_t len = r.size();
// Character code -> 6-bit value. Entries without an explicit initialiser are 0.
static const int B64index[256] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 62, 63, 62, 62, 63, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0, 0, 0, 0, 63, 0, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
};
unsigned char * p = (unsigned char *)data;
// pad is 1 when the last group is incomplete or ends with '='.
int pad = len > 0 && (len % 4 || p[len - 1] == '=');
// L = number of characters that belong to full, unpadded 4-character groups.
const size_t L = ((len + 3) / 4 - pad) * 4;
Str str(L / 4 * 3 + pad, '\0');
for (size_t i = 0, j = 0; i < L; i += 4) {
int n = B64index[p[i]] << 18 | B64index[p[i + 1]] << 12 | B64index[p[i + 2]] << 6 | B64index[p[i + 3]];
str[j++] = n >> 16;
str[j++] = n >> 8 & 0xFF;
str[j++] = n & 0xFF;
}
if (pad) {
// Last group: one byte always, a second byte when a third character is present.
// p[L + 1] may be the zero terminator of r when the input is malformed.
int n = B64index[p[L]] << 18 | B64index[p[L + 1]] << 12;
str[str.size() - 1] = n >> 16;
if (len > L + 2 && p[L + 2] != '=') {
n |= B64index[p[L + 2]] << 6;
str.push_back(n >> 8 & 0xFF);
}
}
// Round-trip check: rejects malformed input that the table lookup above accepted.
auto const res = Str(str.data(), str.size()), encoded = Base64Encode(str.data(), str.size());
ASSERT_MSG(r == encoded, "input " + r + " != re-encoded " + encoded);
return res;
}
// Checks one Base64 pair in both directions: encoding `inp` must give `out`, and
// decoding `out` must give `inp`.
Error TestBase64Single(Str const & inp, Str const & out) {
auto const data = Base64Encode(inp.data(), inp.size());
ASSERT_MSG(data == out, "actual " + data + " ref " + out);
TRYR(decoded, Base64Decode(out.data(), out.size()));
ASSERT_MSG(decoded == inp, "actual " + decoded + " ref " + inp);
return 0;
}
// Self-test for Base64: round-trips the "", "f", "fo", ... "foobar" vectors, then checks
// that a 9-character input is rejected and an 8-character input is accepted.
Error TestBase64() {
TRYI(TestBase64Single("", ""));
TRYI(TestBase64Single("f", "Zg=="));
TRYI(TestBase64Single("fo", "Zm8="));
TRYI(TestBase64Single("foo", "Zm9v"));
TRYI(TestBase64Single("foob", "Zm9vYg=="));
TRYI(TestBase64Single("fooba", "Zm9vYmE="));
TRYI(TestBase64Single("foobar", "Zm9vYmFy"));
{
// Length 9 is not a valid Base64 length: decoding must fail.
Str inp = "KKKKKKKKK";
TRYERR(Base64Decode(inp.data(), inp.size()));
}
{
Str inp = "KKKKKKKK";
TRYI(Base64Decode(inp.data(), inp.size()));
}
return 0;
}
// Signs `data` with HMAC-SHA256. `key_base64` is the key in Base64; the returned
// signature is the 32-byte MAC encoded as Base64. Fails when the key is not valid
// Base64 or when hashing fails.
// NOTE(review): HMAC::Start / Process take u32 lengths, so the Str sizes are narrowed.
Result<Str> SignSha256(Str const & key_base64, Str const & data) {
TRYR(key, Base64Decode(key_base64.data(), key_base64.size()));
HashInterface_Sha256 hi_sha256;
HMAC hmac(hi_sha256);
TRYI(hmac.Start((u8 *)key.data(), key.size()));
TRYI(hmac.Process((u8 *)data.data(), data.size()));
SECURE_ARRAY(u8, hash, 32);
TRYI(hmac.Stop(hash));
return Base64Encode(hash, sizeof(hash));
}
// Checks one SignSha256 vector: signing `data` with the Base64 key `inp` must give
// `sign`. Also fails when SignSha256 itself fails.
Error TestSignSha256Single(Str const & inp, Str const & data, Str const & sign) {
TRYR(res, SignSha256(inp, data));
ASSERT_MSG(res == sign, "actual " + res + " ref " + sign);
return 0;
}
// Self-test for SignSha256: three valid key / data / signature vectors and one key
// that is not valid Base64 (9 characters) and therefore must be rejected.
Error TestSignSha256() {
TRYI(TestSignSha256Single("KKKKKKKK", "RRRRRRRRRRR", "2d5oEcRhy+0wV2iNqOji6N8i93QH8I0KCJA0sg3TVfw="));
TRYERR(TestSignSha256Single("KKKKKKKKK", "RRRRRRRRRRR", "2d5oEcRhy+0wV2iNqOji6N8i93QH8I0KCJA0sg3TVfw="));
TRYI(TestSignSha256Single("Zm9vYmE=", "RRRRRRRRRRR", "3R74OjlpMx7iypFGmEP/nAzNrbVs/h3k1PwSV4+r6LA="));
TRYI(TestSignSha256Single("Zg==", "QQQrt", "/EqzQo3lF0MLLwRbXDjXF7yQaGiiZ4E3aFG9ABv92ZM="));
return 0;
}
// Runs all self-tests in order and stops at the first failure, returning its error
// prefixed with the line and source text of the failing call.
Error Test() {
TRYI(TestSha256());
TRYI(TestHmacSha256());
TRYI(TestHmacSha256v2());
TRYI(TestBase64());
TRYI(TestSignSha256());
return 0;
}
#include <iostream>
// Runs the self-tests, prints "All OK" or the error text, and reports the outcome
// through the exit status (0 on success, 1 on failure) so scripts can detect failures.
int main() {
    auto const res = Test();
    std::cout << (res ? "All OK" : res.Err().data()) << std::endl;
    return res ? 0 : 1;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment