Skip to content

Instantly share code, notes, and snippets.

@iscgar
Created April 27, 2024 19:16
Show Gist options
  • Save iscgar/7ca116f64737d7df69a72623bec90b33 to your computer and use it in GitHub Desktop.
Save iscgar/7ca116f64737d7df69a72623bec90b33 to your computer and use it in GitHub Desktop.
Compile-time AES S-Box and inverse S-Box generation in C++
// A dirty hack to generate the AES S-Box at compile-time.
// Not optimised. Please do not use.
// Requires C++14 for std::integer_sequence
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
// BitSum<Value>::value is the population count (number of set bits) of Value.
// Each recursion step clears the lowest set bit of Value, so the recursion
// depth equals the result.
template<std::uintmax_t Value>
struct BitSum
    : std::integral_constant<std::size_t, 1 + BitSum<(Value & (Value - 1))>::value>
{
};

// Recursion terminator: zero has no set bits.
template<>
struct BitSum<0> : std::integral_constant<std::size_t, 0>
{
};
// RequiredBits<Value>::value is the number of bits needed to represent Value:
// one more than the index of its highest set bit, or 0 when Value is 0.
template<std::uintmax_t Value>
struct RequiredBits
    : std::integral_constant<std::size_t, RequiredBits<Value / 2>::value + 1>
{
};

// Recursion terminator: zero needs no bits.
template<>
struct RequiredBits<0> : std::integral_constant<std::size_t, 0>
{
};
// BitSize<T>::value is the width of T's object representation in bits:
// sizeof(T) multiplied by the number of set bits in an all-ones unsigned char
// (i.e. the number of bits per byte).
template<class T>
struct BitSize
    : std::integral_constant<std::size_t, BitSum<static_cast<unsigned char>(~0)>::value * sizeof(T)>
{
};
// Uint<Bits>::least is a standard unsigned integer type with at least Bits bits:
//   1..8   -> unsigned char
//   9..16  -> unsigned short
//   17..32 -> unsigned int, or unsigned long if unsigned int has fewer than 32 bits
//   33..64 -> std::uint_least64_t
// Widths without their own specialisation round up to the next specialised width.
template<std::size_t Bits>
struct Uint
{
    static_assert(Bits > 0, "Uint: Bits must be positive");
    static_assert(Bits <= 64, "Uint: widths above 64 bits are not supported");
    // For Bits > 64, jump straight to the 64-bit specialisation. The
    // static_assert above is then the only diagnostic, instead of an unbounded
    // recursion through Uint<65>, Uint<66>, ... until the instantiation depth
    // limit is hit.
    using least = typename Uint<(Bits < 64 ? Bits + 1 : 64)>::least;
};
template<> struct Uint<8> { using least = unsigned char; };
template<> struct Uint<16> { using least = unsigned short; };
template<> struct Uint<32>
{
    // unsigned int is only guaranteed 16 bits; fall back to unsigned long.
    using least = std::conditional<
        (BitSize<unsigned int>::value >= 32), unsigned int, unsigned long
    >::type;
};
template<> struct Uint<64> { using least = std::uint_least64_t; };
// All<List<Ts...>, Predicate>::value is true when Predicate<T>::value holds for
// every T in the list. An empty list yields true.
template<class List, template<typename> class Predicate>
struct All;

template<template<typename...> class List, template<typename> class Predicate>
struct All<List<>, Predicate> : std::true_type {};

#if __cplusplus >= 201703L
// C++17: evaluate the whole conjunction with one fold expression.
template<template<typename...> class List, class ...Items, template<typename> class Predicate>
struct All<List<Items...>, Predicate>
    : std::integral_constant<bool, (... && Predicate<Items>::value)> {};
#else
// C++14: test the head, then recurse on the tail.
template<template<typename...> class List, class First, class ...Rest, template<typename> class Predicate>
struct All<List<First, Rest...>, Predicate>
    : std::integral_constant<bool, (Predicate<First>::value && All<List<Rest...>, Predicate>::value)> {};
#endif
// Any<List<Ts...>, Predicate>::value is true when Predicate<T>::value holds for
// at least one T in the list. An empty list yields false.
template<class List, template<typename> class Predicate>
struct Any;

template<template<typename...> class List, template<typename> class Predicate>
struct Any<List<>, Predicate> : std::false_type {};

#if __cplusplus >= 201703L
// C++17: evaluate the whole disjunction with one fold expression.
template<template<typename...> class List, class ...Items, template<typename> class Predicate>
struct Any<List<Items...>, Predicate>
    : std::integral_constant<bool, (... || Predicate<Items>::value)> {};
#else
// C++14: test the head, then recurse on the tail.
template<template<typename...> class List, class First, class ...Rest, template<typename> class Predicate>
struct Any<List<First, Rest...>, Predicate>
    : std::integral_constant<bool, (Predicate<First>::value || Any<List<Rest...>, Predicate>::value)> {};
#endif
// Append<L<Ts...>, U>::type is L<Ts..., U>: adds one element at the end of a
// variadic type list.
template<class List, class T>
struct Append;

template<template<typename...> class Seq, class ...Items, class Tail>
struct Append<Seq<Items...>, Tail>
{
    using type = Seq<Items..., Tail>;
};
// Concat<L<As...>, L<Bs...>, ...>::type concatenates one or more instances of
// the same variadic list template into a single list, preserving order:
//   Concat<L<A...>, L<B...>, L<C...>>::type == L<A..., B..., C...>
// Every argument must be an instance of the same template L.
template<class ...Lists>
struct Concat;

// A single list is already fully concatenated.
template<template<typename...> class List, class ...A>
struct Concat<List<A...>>
{
    using type = List<A...>;
};

// Merge the first two lists, then concatenate that result with the remaining lists.
template<template<typename...> class List, class ...A, class ...B, class ...Rest>
struct Concat<List<A...>, List<B...>, Rest...>
{
    using type = typename Concat<List<A..., B...>, Rest...>::type;
};
// Map<L<Ts...>, Mapper>::type is L<typename Mapper<Ts>::type...>: applies a
// type-level function to every element of a type list.
template<class List, template<typename> class Mapper>
struct Map;

template<template<typename...> class Seq, class ...Elems, template<typename> class F>
struct Map<Seq<Elems...>, F>
{
    using type = Seq<typename F<Elems>::type...>;
};
// Fold<L<T0, ..., Tn>, Folder, Init>::type is the left fold
//   Folder<...Folder<Folder<Init, T0>::type, T1>::type..., Tn>::type
// An empty list folds to Init.
template<class List, template<typename, typename> class Folder, class Init>
struct Fold;

template<template<typename...> class List, template<typename, typename> class Folder, class Init>
struct Fold<List<>, Folder, Init>
{
    using type = Init;
};

template<template<typename...> class List, class Head, class ...Tail, template<typename, typename> class Folder, class Init>
struct Fold<List<Head, Tail...>, Folder, Init>
{
private:
    // Accumulator after consuming Head.
    using Next = typename Folder<Init, Head>::type;
public:
    using type = typename Fold<List<Tail...>, Folder, Next>::type;
};
// Transform<Source<Ts...>, Target>::type re-packs the elements of one variadic
// template instance into another template: Target<Ts...>.
template<class From, template<typename...> class To>
struct Transform;

template<template<typename...> class Source, class ...Elems, template<typename...> class Target>
struct Transform<Source<Elems...>, Target>
{
    using type = Target<Elems...>;
};
// Filter<L<Ts...>, Predicate>::type keeps, in their original relative order,
// only the elements T for which Predicate<T>::value is true.
template<class List, template<typename> class Predicate>
struct Filter;

template<template<typename...> class List, template<typename> class Predicate>
struct Filter<List<>, Predicate>
{
    using type = List<>;
};

template<template<typename...> class List, class Head, class ...Tail, template<typename> class Predicate>
struct Filter<List<Head, Tail...>, Predicate>
{
private:
    // Either the one-element list L<Head> or the empty list L<>.
    using Kept = typename std::conditional<Predicate<Head>::value, List<Head>, List<>>::type;
    // The already-filtered remainder of the list.
    using Rest = typename Filter<List<Tail...>, Predicate>::type;
public:
    using type = typename Concat<Kept, Rest>::type;
};
// Pair<First, Second> is a type-level pair: a tag type that exposes its two
// components as the nested aliases `first` and `second`.
template<typename First, typename Second>
struct Pair
{
    using second = Second;
    using first = First;
};
// Zip<L<As...>, L<Bs...>>::type pairs two lists element-wise:
//   L<Pair<A0, B0>, Pair<A1, B1>, ...>
// Both lists must be instances of the same template L. The two packs are
// expanded together, so they must have the same length.
template<class A, class B>
struct Zip;

template<template<typename...> class Seq, class ...Lefts, class ...Rights>
struct Zip<Seq<Lefts...>, Seq<Rights...>>
{
    using type = Seq<Pair<Lefts, Rights>...>;
};
// Binds one argument of a two-parameter predicate template, producing
// one-parameter alias templates that can be passed where a unary predicate is
// expected:
//   UnaryValueFirst<U>  == BinPredicate<U, T>   (U is the first argument)
//   UnaryValueSecond<U> == BinPredicate<T, U>   (U is the second argument)
template<class T, template<typename, typename> class BinPredicate>
struct BinaryToUnaryPredicate
{
    template<class U>
    using UnaryValueSecond = BinPredicate<T, U>;

    template<class U>
    using UnaryValueFirst = BinPredicate<U, T>;
};
// Adapts a three-way comparator template (whose ::value is negative, zero or
// positive) into boolean binary predicates on two types. Each alias is a
// std::integral_constant<bool, ...>.
template<template<typename, typename> class Comparator>
struct ComparatorPredicate
{
    template<class L, class R>
    using Equal = std::integral_constant<bool, (Comparator<L, R>::value == 0)>;
    template<class L, class R>
    using Less = std::integral_constant<bool, (Comparator<L, R>::value < 0)>;
    template<class L, class R>
    using LessEqual = std::integral_constant<bool, !(Comparator<L, R>::value > 0)>;
    template<class L, class R>
    using Greater = std::integral_constant<bool, (0 < Comparator<L, R>::value)>;
    template<class L, class R>
    using GreaterEqual = std::integral_constant<bool, !(Comparator<L, R>::value < 0)>;
};
// QuickSort<L<Ts...>, Comparator>::type is the list sorted in ascending order.
// Comparator<A, B> is a three-way comparator template whose ::value is negative
// when A sorts before B, zero when they are equivalent, and positive otherwise.
// It is turned into boolean predicates through ComparatorPredicate.
template<class List, template<typename, typename> class Comparator>
class QuickSort;
// Base case: the empty list is already sorted.
template<template<typename...> class List, template<typename, typename> class Comparator>
class QuickSort<List<>, Comparator>
{
public:
using type = List<>;
};
// Base case: a single-element list is already sorted.
template<template<typename...> class List, class Pivot, template<typename, typename> class Comparator>
class QuickSort<List<Pivot>, Comparator>
{
public:
using type = List<Pivot>;
};
// Recursive case: the first element is the pivot. The tail is split with Filter
// (which keeps relative order) into the elements that compare Less than the
// pivot and those that compare GreaterEqual. Both parts are sorted recursively
// and concatenated around the pivot. Elements equivalent to the pivot land after
// it, so equivalent elements keep their original relative order (stable sort).
// NOTE(review): an already-sorted input gives O(n) recursion depth and
// O(n^2) Filter instantiations.
template<template<typename...> class List, class Pivot, class Head, class ...T, template<typename, typename> class Comparator>
class QuickSort<List<Pivot, Head, T...>, Comparator>
{
// Partition<Pred>: the tail elements U for which Pred<U, Pivot>::value is true.
template<template<typename, typename> class Predicate>
using Partition = Filter<List<Head, T...>, BinaryToUnaryPredicate<Pivot, Predicate>::template UnaryValueFirst>;
// Tail elements that sort strictly before the pivot.
using LowerPart = typename Partition<ComparatorPredicate<Comparator>::template Less>::type;
// Tail elements equivalent to the pivot or sorting after it.
using UpperPart = typename Partition<ComparatorPredicate<Comparator>::template GreaterEqual>::type;
public:
using type = typename Concat<
typename QuickSort<LowerPart, Comparator>::type,
List<Pivot>,
typename QuickSort<UpperPart, Comparator>::type
>::type;
};
// IntegralComparator<integral_constant<T, A>, integral_constant<T, B>>::value is
// a three-way comparison of A and B: -1 if A < B, 0 if A == B, and 1 if A > B.
// Both operands must wrap the same integer type T.
template<class T, class U>
struct IntegralComparator;

template<class T, T A, T B>
struct IntegralComparator<std::integral_constant<T, A>, std::integral_constant<T, B>>
    : std::integral_constant<int, ((A < B) ? -1 : ((B < A) ? 1 : 0))>
{
};
// IntegerSequenceConverter<Seq<Int, Vs...>, Out>::type turns an integer sequence
// (such as std::integer_sequence<Int, Vs...>) into the type list
// Out<std::integral_constant<Int, Vs>...>.
template<class FromSeq, template<typename...> class ToSeq>
struct IntegerSequenceConverter;

template<template<typename U, U...> class Seq, class Int, Int ...Values, template<typename...> class Out>
struct IntegerSequenceConverter<Seq<Int, Values...>, Out>
{
    using type = Out<std::integral_constant<Int, Values>...>;
};
// SwapPair<Pair<A, B>>::type is Pair<B, A>. It is only defined for Pair instances.
template<class PairType>
struct SwapPair;

template<class First, class Second>
struct SwapPair<Pair<First, Second>>
{
    using type = Pair<Second, First>;
};
// TypeList<Types...> is an empty tag type whose only purpose is to carry a
// pack of types through the metafunctions in this file.
template<typename... Types>
struct TypeList
{
};
template<std::size_t Bits, typename Uint<Bits>::least ...Vals>
struct Perm
{
public:
using Element = typename Uint<Bits>::least;
private:
static constexpr std::size_t N = std::size_t(1) << Bits;
static_assert(sizeof...(Vals) == N, "");
using ValueSeq = TypeList<std::integral_constant<Element, Vals>...>;
template<class V>
using ValueEnforcer = std::integral_constant<bool, (V::value < N)>;
static_assert(All<ValueSeq, ValueEnforcer>::value, "");
using IndexSeq = typename IntegerSequenceConverter<
std::make_index_sequence<sizeof...(Vals)>,
TypeList
>::type;
template<class A, class B>
using XorFolder = std::integral_constant<Element, (A::value ^ B::value)>;
template<class A, class B>
using SumFolder = std::integral_constant<std::uintmax_t, (A::value + B::value)>;
static_assert(
Fold<IndexSeq, XorFolder, std::integral_constant<Element, 0>
>::type::value == Fold<ValueSeq, XorFolder, std::integral_constant<Element, 0>
>::type::value, "");
static_assert(
Fold<IndexSeq, SumFolder, std::integral_constant<std::uintmax_t, 0>
>::type::value == Fold<ValueSeq, SumFolder, std::integral_constant<std::uintmax_t, 0>
>::type::value, "");
public:
static inline constexpr std::size_t size() { return N; }
static constexpr const Element P[sizeof...(Vals)] = {
Vals...
};
};
template<std::size_t Bits, typename Uint<Bits>::least ...Vals>
constexpr const typename Perm<Bits, Vals...>::Element Perm<Bits, Vals...>::P[sizeof...(Vals)];
// Lifts a comparator on element types to a comparator on Pair-like types
// (anything with nested `first`/`second` types), comparing one component only:
//   First<L, R>  == Comparator<L::first,  R::first>
//   Second<L, R> == Comparator<L::second, R::second>
template<template<typename, typename> class Comparator>
struct PairComparator
{
    template<class Lhs, class Rhs>
    using Second = Comparator<typename Lhs::second, typename Rhs::second>;

    template<class Lhs, class Rhs>
    using First = Comparator<typename Lhs::first, typename Rhs::first>;
};
// PairExtractorFirst<Pair<A, B>>::type is A, and PairExtractorSecond<Pair<A, B>>::type
// is B. Both are only defined for Pair instances. They have the
// one-parameter ::type shape expected of a Map mapper.
template<class P>
struct PairExtractorFirst;

template<class P>
struct PairExtractorSecond;

template<class A, class B>
struct PairExtractorFirst<Pair<A, B>>
{
    using type = A;
};

template<class A, class B>
struct PairExtractorSecond<Pair<A, B>>
{
    using type = B;
};
// InvertPermutation<Perm<Bits, Vals...>>::type is the inverse permutation: the
// Perm whose entry at position v is the index i with Vals[i] == v. It is built
// by zipping each value with its index, sorting the pairs by value, and reading
// the indices back in that order.
template<class Perm>
class InvertPermutation;

template<std::size_t Bits, typename Uint<Bits>::least ...Vals>
class InvertPermutation<Perm<Bits, Vals...>>
{
    using Element = typename Perm<Bits, Vals...>::Element;
    // The permutation's values, as a type list.
    using Values = TypeList<std::integral_constant<Element, Vals>...>;
    // The indices 0 .. N-1, as a type list.
    using Indices = typename IntegerSequenceConverter<
        std::make_index_sequence<sizeof...(Vals)>, TypeList>::type;
    // Pair<value, index> sorted by value. For a valid permutation, the k-th pair
    // holds the index whose value is k.
    using InversePermSeq = typename QuickSort<
        typename Zip<Values, Indices>::type,
        PairComparator<IntegralComparator>::template First
    >::type;
    template<class ...InverseVals>
    using TransformAdapter = Perm<Bits, Element(InverseVals::second::value)...>;
public:
    using type = typename Transform<InversePermSeq, TransformAdapter>::type;
};

// Shorthand for InvertPermutation<Perm>::type.
template<class Perm>
using InversePerm = typename InvertPermutation<Perm>::type;
// Rotates the 8-bit value `x` left by `shift` bits; the shift is taken modulo 8.
// This is a constexpr function rather than a macro, so the argument is
// evaluated once and type-checked, and shift counts above 8 no longer produce a
// negative (undefined) right-shift. The all-caps name is kept so existing call
// sites such as ROTL8(q, 1) still compile, including in constant expressions.
constexpr std::uint8_t ROTL8(std::uint8_t x, unsigned shift)
{
    return static_cast<std::uint8_t>((x << (shift % 8u)) | (x >> (8u - shift % 8u)));
}
// AesSBoxQ<BaseQ>::Q is BaseQ divided by 3 in GF(2^8), which equals
// multiplication by 0xf6. Q0, Q1 and Q2 are the intermediate steps of the
// shift-and-xor evaluation.
template<std::uint8_t BaseQ>
struct AesSBoxQ
{
    static constexpr std::uint8_t Q0 = static_cast<std::uint8_t>(BaseQ ^ (BaseQ << 1));
    static constexpr std::uint8_t Q1 = static_cast<std::uint8_t>(Q0 ^ (Q0 << 2));
    static constexpr std::uint8_t Q2 = static_cast<std::uint8_t>(Q1 ^ (Q1 << 4));
    // Final reduction step: fold the top bit back in with 0x09.
    static constexpr std::uint8_t Q = static_cast<std::uint8_t>(Q2 ^ ((Q2 & 0x80) ? 0x09 : 0x00));
};
// AesSBoxGeneratorHelper<N> unrolls, at compile time, N steps of the S-box
// generation loop:
//   P: starts at 1 and is multiplied by 3 at each step;
//   Q: starts at 1 and is divided by 3 at each step (via AesSBoxQ), so Q is
//      always the inverse of P;
//   X: the affine transformation of Q, i.e. the S-box output for input P.
// `type` is a TypeList of Pair<input, output> for every step so far. It is
// seeded (N == 0) with the entry for input 0, whose output is 0x63.
template<std::size_t N>
struct AesSBoxGeneratorHelper
{
    using Base = AesSBoxGeneratorHelper<N-1>;
    // Multiply p by 3
    static constexpr std::uint8_t P = static_cast<std::uint8_t>(
        Base::P ^ (Base::P << 1) ^ ((Base::P & 0x80) ? 0x1b : 0));
    // Divide q by 3 (equals multiplication by 0xf6)
    static constexpr std::uint8_t Q = AesSBoxQ<Base::Q>::Q;
    // Affine transformation of the inverse.
    using X = std::integral_constant<std::uint8_t,
        (0x63 ^ Q ^ ROTL8(Q, 1) ^ ROTL8(Q, 2) ^ ROTL8(Q, 3) ^ ROTL8(Q, 4))>;
    using type = typename Append<
        typename Base::type,
        Pair<std::integral_constant<std::uint8_t, P>, X>
    >::type;
};

// Seed: the running values start at 1, and input 0 maps to 0x63.
template<>
struct AesSBoxGeneratorHelper<0>
{
    static constexpr std::uint8_t P = 1;
    static constexpr std::uint8_t Q = 1;
    using X = std::integral_constant<std::uint8_t, 0x63>;
    using type = TypeList<Pair<std::integral_constant<std::uint8_t, 0>, X>>;
};
// SBox<T, TypeList<integral_constant<T, Vs>...>>::type is the Perm whose table
// is Vs.... Its bit width is the number of bits needed to index the last entry,
// so Perm only accepts the result when the list length is a power of two.
template<class T, class List>
struct SBox;

template<class T, T ...Vals>
struct SBox<T, TypeList<std::integral_constant<T, Vals>...>>
{
private:
    static constexpr std::size_t Count = sizeof...(Vals);
    static_assert(Count > 0, "");
public:
    using type = Perm<RequiredBits<Count - 1>::value, Vals...>;
};
using AesSBox = typename SBox<std::uint8_t,
typename Map<
typename QuickSort<
typename AesSBoxGeneratorHelper<255>::type,
PairComparator<IntegralComparator>::template First
>::type,
PairExtractorSecond
>::type
>::type;
using AesInverseSBox = typename InvertPermutation<AesSBox>::type;
#include <cstdio>
// Prints `table` to stdout as "<name> = {" followed by comma-terminated hex
// entries, sixteen per line, and a closing "}".
// Iterates over the N elements of the array, not its sizeof in bytes, so a
// table whose element type is wider than one byte is not over-read.
template<class T, std::size_t N>
static void PrintTable(const char* name, const T (&table)[N])
{
    std::printf("%s = {", name);
    for (std::size_t i = 0; i < N; ++i)
    {
        if (i % 16 == 0)
        {
            std::printf("\n ");
        }
        // %x expects an unsigned int.
        std::printf(" 0x%02x,", static_cast<unsigned>(table[i]));
    }
    std::printf("\n}\n");
}

// Dumps the forward (FBOX) and inverse (IBOX) AES S-boxes computed at compile time.
int main() {
    PrintTable("FBOX", AesSBox::P);
    std::printf("\n");
    PrintTable("IBOX", AesInverseSBox::P);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment