Created
April 27, 2024 19:16
-
-
Save iscgar/7ca116f64737d7df69a72623bec90b33 to your computer and use it in GitHub Desktop.
Compile-time AES S-Box and inverse S-Box generation in C++
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// A dirty hack to generate the AES S-Box at compile-time.
// Not optimised. Please do not use.
// Requires C++14 for std::integer_sequence
#include <cstddef> | |
#include <cstdint> | |
#include <type_traits> | |
#include <utility> | |
// Counts the set bits of Value at compile time.
// Every level contributes the lowest bit and recurses on the remaining bits.
template<std::uintmax_t Value>
struct BitSum
    : std::integral_constant<std::size_t, BitSum<(Value >> 1)>::value + (Value & 1)>
{
};

// Recursion anchor: zero has no set bits.
template<>
struct BitSum<0>
    : std::integral_constant<std::size_t, 0>
{
};
// Number of bits needed to represent Value, i.e. the position of its highest
// set bit counted from one.  Each level strips one bit and adds one to the
// count.
template<std::uintmax_t Value>
struct RequiredBits
    : std::integral_constant<std::size_t, RequiredBits<(Value >> 1)>::value + 1>
{
};

// Recursion anchor: zero needs no bits.
template<>
struct RequiredBits<0>
    : std::integral_constant<std::size_t, 0>
{
};
template<class T> | |
struct BitSize : public std::integral_constant<std::size_t, sizeof(T) * BitSum<static_cast<unsigned char>(~0)>::value> {}; | |
// Maps a bit count to an unsigned integer type holding at least that many
// bits.  Widths without a dedicated specialisation are rounded up one bit at a
// time until a specialised width is reached.
template<std::size_t Bits>
struct Uint
{
    static_assert(Bits > 0, "Uint requires at least one bit");
    static_assert(Bits <= 64, "Uint supports at most 64 bits");
    // Clamp the next step to 64 so that an oversized request resolves to the
    // widest width and fails on the assertion above, instead of recursing
    // until the compiler's template depth limit is hit.
    using least = typename Uint<(Bits < 64 ? Bits + 1 : 64)>::least;
};

// Exactly 8 bits requested.
template<>
struct Uint<8>
{
    using least = unsigned char;
};

// Exactly 16 bits requested.
template<>
struct Uint<16>
{
    using least = unsigned short;
};
template<> struct Uint<32> | |
{ | |
using least = std::conditional< | |
(BitSize<unsigned int>::value >= 32), unsigned int, unsigned long | |
>::type; | |
}; | |
template<> struct Uint<64> { using least = std::uint_least64_t; }; | |
// Compile-time conjunction over a type list: `value` is true when
// Predicate<T>::value is true for every element T of the list.
template<class List, template<typename> class Predicate>
struct All;

// An empty list satisfies the predicate vacuously.
template<template<typename...> class List, template<typename> class Predicate>
struct All<List<>, Predicate> : std::true_type {};

#if defined(__cplusplus) && __cplusplus >= 201703L
// C++17 and later: one fold expression over the whole pack.
template<template<typename...> class List, class ...Items, template<typename> class Predicate>
struct All<List<Items...>, Predicate>
    : std::integral_constant<bool, (Predicate<Items>::value && ...)>
{
};
#else
// Before C++17: test the first element, then recurse on the remainder.
template<template<typename...> class List, class First, class ...Rest, template<typename> class Predicate>
struct All<List<First, Rest...>, Predicate>
    : std::integral_constant<
          bool,
          (Predicate<First>::value && All<List<Rest...>, Predicate>::value)>
{
};
#endif
// Compile-time disjunction over a type list: `value` is true when
// Predicate<T>::value is true for at least one element T of the list.
template<class List, template<typename> class Predicate>
struct Any;

// Nothing in an empty list can satisfy the predicate.
template<template<typename...> class List, template<typename> class Predicate>
struct Any<List<>, Predicate> : std::false_type {};

#if defined(__cplusplus) && __cplusplus >= 201703L
// C++17 and later: one fold expression over the whole pack.
template<template<typename...> class List, class ...Items, template<typename> class Predicate>
struct Any<List<Items...>, Predicate>
    : std::integral_constant<bool, (Predicate<Items>::value || ...)>
{
};
#else
// Before C++17: test the first element, then recurse on the remainder.
template<template<typename...> class List, class First, class ...Rest, template<typename> class Predicate>
struct Any<List<First, Rest...>, Predicate>
    : std::integral_constant<
          bool,
          (Predicate<First>::value || Any<List<Rest...>, Predicate>::value)>
{
};
#endif
// Appends one type to the end of a type list.
template<class List, class T>
struct Append;

template<template<typename...> class List, class ...Existing, class Added>
struct Append<List<Existing...>, Added>
{
    using type = List<Existing..., Added>;
};
// Concatenates one or more type lists built from the same list template,
// preserving the order of the lists and of the elements inside them.
template<class ...Lists>
struct Concat;

// A single list is already fully concatenated.
template<template<typename...> class List, class ...A>
struct Concat<List<A...>>
{
    using type = List<A...>;
};

// Merge the two leading lists, then continue with whatever lists remain.
template<template<typename...> class List, class ...A, class ...B, class ...Rest>
struct Concat<List<A...>, List<B...>, Rest...>
{
    using type = typename Concat<List<A..., B...>, Rest...>::type;
};
// Applies a unary metafunction to every element of a type list; the result
// list holds Mapper<T>::type for each element T, in the same order.
template<class List, template<typename> class Mapper>
struct Map;

template<template<typename...> class List, class ...Items, template<typename> class Mapper>
struct Map<List<Items...>, Mapper>
{
    using type = List<typename Mapper<Items>::type...>;
};
// Left fold over a type list.  Starting from Init, each element is combined
// with the running result through Folder<Accumulated, Element>::type.
template<class List, template<typename, typename> class Folder, class Init>
struct Fold;

// Nothing left to combine: the accumulated type is the result.
template<template<typename...> class List, template<typename, typename> class Folder, class Acc>
struct Fold<List<>, Folder, Acc>
{
    using type = Acc;
};

// Combine the accumulator with the first element, then fold the remainder.
template<template<typename...> class List, class First, class ...Rest, template<typename, typename> class Folder, class Acc>
struct Fold<List<First, Rest...>, Folder, Acc>
{
private:
    using Next = typename Folder<Acc, First>::type;

public:
    using type = typename Fold<List<Rest...>, Folder, Next>::type;
};
// Re-wraps the template arguments of one variadic template instance in another
// template: From<A...> becomes To<A...>.
template<class From, template<typename...> class To>
struct Transform;

template<template<typename...> class Source, class ...Args, template<typename...> class Target>
struct Transform<Source<Args...>, Target>
{
    using type = Target<Args...>;
};
template<class List, template<typename> class Predicate> | |
struct Filter; | |
template<template<typename...> class List, template<typename> class Predicate> | |
struct Filter<List<>, Predicate> | |
{ | |
using type = List<>; | |
}; | |
template<template<typename...> class List, class Head, class ...T, template<typename> class Predicate> | |
struct Filter<List<Head, T...>, Predicate> | |
{ | |
using type = typename Concat< | |
typename std::conditional<Predicate<Head>::value, List<Head>, List<>>::type, | |
typename Filter<List<T...>, Predicate>::type | |
>::type; | |
}; | |
// A pair of types, exposed as the member aliases `first` and `second`.
template<class First, class Second>
struct Pair
{
    using first = First;
    using second = Second;
};
template<class A, class B> | |
struct Zip; | |
template<template<typename...> class List, class ...A, class ...B> | |
struct Zip<List<A...>, List<B...>> | |
{ | |
using type = List<Pair<A, B>...>; | |
}; | |
// Turns a binary predicate into a unary one by fixing one argument to T.
template<class T, template<typename, typename> class BinPredicate>
struct BinaryToUnaryPredicate
{
    // The free argument goes first:  BinPredicate<Arg, T>.
    template<class Arg>
    using UnaryValueFirst = BinPredicate<Arg, T>;

    // The free argument goes second: BinPredicate<T, Arg>.
    template<class Arg>
    using UnaryValueSecond = BinPredicate<T, Arg>;
};
// Derives boolean relations from a three-way comparator.  Each alias compares
// Comparator<Lhs, Rhs>::value against zero with the operator its name
// suggests.
template<template<typename, typename> class Comparator>
struct ComparatorPredicate
{
    template<class Lhs, class Rhs>
    using Less = std::integral_constant<bool, (Comparator<Lhs, Rhs>::value < 0)>;

    template<class Lhs, class Rhs>
    using LessEqual = std::integral_constant<bool, (Comparator<Lhs, Rhs>::value <= 0)>;

    template<class Lhs, class Rhs>
    using Equal = std::integral_constant<bool, (Comparator<Lhs, Rhs>::value == 0)>;

    template<class Lhs, class Rhs>
    using Greater = std::integral_constant<bool, (Comparator<Lhs, Rhs>::value > 0)>;

    template<class Lhs, class Rhs>
    using GreaterEqual = std::integral_constant<bool, (Comparator<Lhs, Rhs>::value >= 0)>;
};
template<class List, template<typename, typename> class Comparator> | |
class QuickSort; | |
template<template<typename...> class List, template<typename, typename> class Comparator> | |
class QuickSort<List<>, Comparator> | |
{ | |
public: | |
using type = List<>; | |
}; | |
template<template<typename...> class List, class Pivot, template<typename, typename> class Comparator> | |
class QuickSort<List<Pivot>, Comparator> | |
{ | |
public: | |
using type = List<Pivot>; | |
}; | |
template<template<typename...> class List, class Pivot, class Head, class ...T, template<typename, typename> class Comparator> | |
class QuickSort<List<Pivot, Head, T...>, Comparator> | |
{ | |
template<template<typename, typename> class Predicate> | |
using Partition = Filter<List<Head, T...>, BinaryToUnaryPredicate<Pivot, Predicate>::template UnaryValueFirst>; | |
using LowerPart = typename Partition<ComparatorPredicate<Comparator>::template Less>::type; | |
using UpperPart = typename Partition<ComparatorPredicate<Comparator>::template GreaterEqual>::type; | |
public: | |
using type = typename Concat< | |
typename QuickSort<LowerPart, Comparator>::type, | |
List<Pivot>, | |
typename QuickSort<UpperPart, Comparator>::type | |
>::type; | |
}; | |
// Three-way comparison of two std::integral_constant types that share the same
// value type: `value` is -1 when the first is smaller, 0 when both are equal
// and 1 when the first is larger.
template<class T, class U>
struct IntegralComparator;

template<class T, T Lhs, T Rhs>
struct IntegralComparator<std::integral_constant<T, Lhs>, std::integral_constant<T, Rhs>>
    : std::integral_constant<int, (Lhs < Rhs ? -1 : (Rhs < Lhs ? 1 : 0))>
{
};
// Converts a value sequence such as std::integer_sequence<T, Vals...> into a
// type list ToSeq<std::integral_constant<T, Vals>...>.
template<class FromSeq, template<typename...> class ToSeq>
struct IntegerSequenceConverter;

template<template<typename U, U...> class Seq, class T, T ...Vals, template<typename...> class ToSeq>
struct IntegerSequenceConverter<Seq<T, Vals...>, ToSeq>
{
    using type = ToSeq<std::integral_constant<T, Vals>...>;
};
template<class Pair> | |
struct SwapPair; | |
template<class A, class B> | |
struct SwapPair<Pair<A, B>> | |
{ | |
using type = Pair<B, A>; | |
}; | |
// An empty tag type whose template arguments form a compile-time list of types.
template<class ...Items>
struct TypeList
{
};
template<std::size_t Bits, typename Uint<Bits>::least ...Vals> | |
struct Perm | |
{ | |
public: | |
using Element = typename Uint<Bits>::least; | |
private: | |
static constexpr std::size_t N = std::size_t(1) << Bits; | |
static_assert(sizeof...(Vals) == N, ""); | |
using ValueSeq = TypeList<std::integral_constant<Element, Vals>...>; | |
template<class V> | |
using ValueEnforcer = std::integral_constant<bool, (V::value < N)>; | |
static_assert(All<ValueSeq, ValueEnforcer>::value, ""); | |
using IndexSeq = typename IntegerSequenceConverter< | |
std::make_index_sequence<sizeof...(Vals)>, | |
TypeList | |
>::type; | |
template<class A, class B> | |
using XorFolder = std::integral_constant<Element, (A::value ^ B::value)>; | |
template<class A, class B> | |
using SumFolder = std::integral_constant<std::uintmax_t, (A::value + B::value)>; | |
static_assert( | |
Fold<IndexSeq, XorFolder, std::integral_constant<Element, 0> | |
>::type::value == Fold<ValueSeq, XorFolder, std::integral_constant<Element, 0> | |
>::type::value, ""); | |
static_assert( | |
Fold<IndexSeq, SumFolder, std::integral_constant<std::uintmax_t, 0> | |
>::type::value == Fold<ValueSeq, SumFolder, std::integral_constant<std::uintmax_t, 0> | |
>::type::value, ""); | |
public: | |
static inline constexpr std::size_t size() { return N; } | |
static constexpr const Element P[sizeof...(Vals)] = { | |
Vals... | |
}; | |
}; | |
template<std::size_t Bits, typename Uint<Bits>::least ...Vals> | |
constexpr const typename Perm<Bits, Vals...>::Element Perm<Bits, Vals...>::P[sizeof...(Vals)]; | |
// Lifts a comparator on element types to a comparator on pair-like types by
// comparing only one member of each pair.
template<template<typename, typename> class Comparator>
struct PairComparator
{
    // Compares the `first` members of the two pairs.
    template<class Lhs, class Rhs>
    using First = Comparator<typename Lhs::first, typename Rhs::first>;

    // Compares the `second` members of the two pairs.
    template<class Lhs, class Rhs>
    using Second = Comparator<typename Lhs::second, typename Rhs::second>;
};
template<class Pair> | |
struct PairExtractorFirst; | |
template<class A, class B> | |
struct PairExtractorFirst<Pair<A, B>> | |
{ | |
using type = A; | |
}; | |
template<class Pair> | |
struct PairExtractorSecond; | |
template<class A, class B> | |
struct PairExtractorSecond<Pair<A, B>> | |
{ | |
using type = B; | |
}; | |
template<class Perm> | |
class InvertPermutation; | |
template<std::size_t Bits, typename Uint<Bits>::least ...Vals> | |
class InvertPermutation<Perm<Bits, Vals...>> | |
{ | |
using Element = typename Perm<Bits, Vals...>::Element; | |
using InversePermSeq = typename QuickSort< | |
typename Zip< | |
TypeList<std::integral_constant<Element, Vals>...>, | |
typename IntegerSequenceConverter< | |
std::make_index_sequence<sizeof...(Vals)>, TypeList>::type | |
>::type, | |
PairComparator<IntegralComparator>::template First | |
>::type; | |
template<class ...InverseVals> | |
using TransformAdapter = Perm<Bits, Element(InverseVals::second::value)...>; | |
public: | |
using type = typename Transform<InversePermSeq, TransformAdapter>::type; | |
}; | |
template<class Perm> | |
using InversePerm = typename InvertPermutation<Perm>::type; | |
// Rotates the 8-bit value `x` left by `shift` bit positions.
//
// Implemented as a constexpr function rather than a macro so that each
// argument is evaluated exactly once and is type-checked, while the result
// stays usable in constant expressions.  Both shift amounts are reduced modulo
// 8, so a shift of 0 or 8 returns `x` unchanged.
constexpr std::uint8_t ROTL8(std::uint8_t x, unsigned shift)
{
    return static_cast<std::uint8_t>((x << (shift & 7u)) | (x >> ((8u - shift) & 7u)));
}
// One step of the S-Box generator's `q` sequence.
// Divides BaseQ by 3 (equals multiplication by 0xf6); Q is the result and
// Q0..Q2 are the intermediate shift-and-xor stages, each truncated to 8 bits.
template<std::uint8_t BaseQ>
struct AesSBoxQ
{
    static constexpr std::uint8_t Q0 = static_cast<std::uint8_t>(BaseQ ^ (BaseQ << 1));
    static constexpr std::uint8_t Q1 = static_cast<std::uint8_t>(Q0 ^ (Q0 << 2));
    static constexpr std::uint8_t Q2 = static_cast<std::uint8_t>(Q1 ^ (Q1 << 4));
    // Final correction, applied only when the top bit of Q2 is set.
    static constexpr std::uint8_t Q =
        (Q2 & 0x80) != 0 ? static_cast<std::uint8_t>(Q2 ^ 0x09) : Q2;
};
template<std::size_t N> | |
struct AesSBoxGeneratorHelper | |
{ | |
using Base = AesSBoxGeneratorHelper<N-1>; | |
// Multiply p by 3 | |
static constexpr std::uint8_t P = Base::P ^ std::uint8_t(Base::P << 1) ^ std::uint8_t(Base::P & 0x80 ? 0x1b : 0); | |
// Divide q by 3 (equals multiplication by 0xf6) | |
static constexpr std::uint8_t Q = AesSBoxQ<Base::Q>::Q; | |
// Compute the affine transformation | |
using X = std::integral_constant< | |
std::uint8_t, (Q ^ ROTL8(Q, 1) ^ ROTL8(Q, 2) ^ ROTL8(Q, 3) ^ ROTL8(Q, 4) ^ 0x63)>; | |
public: | |
using type = typename Append< | |
typename Base::type, | |
Pair<std::integral_constant<std::uint8_t, P>, X> | |
>::type; | |
}; | |
template<> | |
struct AesSBoxGeneratorHelper<0> | |
{ | |
static constexpr std::uint8_t P = 1; | |
static constexpr std::uint8_t Q = 1; | |
using X = std::integral_constant<std::uint8_t, 0x63>; | |
public: | |
using type = TypeList<Pair<std::integral_constant<std::uint8_t, 0>, X>>; | |
}; | |
template<class T, class List> | |
struct SBox; | |
template<class T, T ...Vals> | |
struct SBox<T, TypeList<std::integral_constant<T, Vals>...>> | |
{ | |
static_assert(sizeof...(Vals) > 0, ""); | |
using type = Perm<RequiredBits<sizeof...(Vals) - 1>::value, Vals...>; | |
}; | |
using AesSBox = typename SBox<std::uint8_t, | |
typename Map< | |
typename QuickSort< | |
typename AesSBoxGeneratorHelper<255>::type, | |
PairComparator<IntegralComparator>::template First | |
>::type, | |
PairExtractorSecond | |
>::type | |
>::type; | |
using AesInverseSBox = typename InvertPermutation<AesSBox>::type; | |
#include <cstdio> | |
int main() { | |
std::printf("FBOX = {"); | |
for (std::size_t i = 0; i < sizeof(AesSBox::P); ++i) | |
{ | |
if (i % 16 == 0) | |
{ | |
std::printf("\n "); | |
}; | |
std::printf(" 0x%02x,", AesSBox::P[i]); | |
} | |
std::printf("\n}\n"); | |
std::printf("\nIBOX = {"); | |
for (std::size_t i = 0; i < sizeof(AesInverseSBox::P); ++i) | |
{ | |
if (i % 16 == 0) | |
{ | |
std::printf("\n "); | |
}; | |
std::printf(" 0x%02x,", AesInverseSBox::P[i]); | |
} | |
std::printf("\n}\n"); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment