Skip to content

Instantly share code, notes, and snippets.

@Thealexbarney
Last active September 17, 2022 22:05
Show Gist options
  • Save Thealexbarney/9f75883786a9f3100408ff795fb95d85 to your computer and use it in GitHub Desktop.
Save Thealexbarney/9f75883786a9f3100408ff795fb95d85 to your computer and use it in GitHub Desktop.
C# AES-NI 128-bit
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
namespace LibHac.Crypto
{
    /// <summary>
    /// AES-128 context implemented with x86 AES-NI hardware intrinsics.
    /// Provides ECB, CBC and CTR modes; every operation transforms <c>data</c> in place.
    /// </summary>
    /// <remarks>
    /// The <see cref="Aes"/>, <see cref="Sse2"/> and <see cref="Ssse3"/> intrinsics are used without an
    /// <c>IsSupported</c> check; on hardware lacking them those calls throw
    /// <see cref="PlatformNotSupportedException"/>, so callers are expected to check support first.
    /// </remarks>
    public class AesContext : IAesContext
    {
        // Round-key schedule (20 entries):
        //   [0..10]  encryption round keys, [0] being the cipher key itself;
        //   [11..19] InverseMixColumns(keys[1..9]), the decryption keys for the
        //            equivalent inverse cipher (AESDEC expects mixed-column-inverted keys).
        private Vector128<byte>[] RoundKeys { get; }

        /// <summary>
        /// The 16-byte IV (CBC) or counter block (CTR). CBC and CTR operations write the
        /// updated value back here so consecutive calls continue the same stream.
        /// </summary>
        public byte[] Iv { get; } = new byte[0x10];

        /// <summary>Creates a context and expands the AES-128 key schedule.</summary>
        /// <param name="key">The cipher key. Only the first 16 bytes are used.</param>
        /// <exception cref="ArgumentException"><paramref name="key"/> is shorter than 16 bytes.</exception>
        public AesContext(Span<byte> key)
        {
            // KeyExpansion reads a full 16-byte vector starting at key[0]; without this
            // check a shorter span would be read past its end.
            if (key.Length < 0x10)
            {
                throw new ArgumentException("AES-128 key must be at least 16 bytes.", nameof(key));
            }

            RoundKeys = KeyExpansion(key);
        }

        /// <summary>
        /// Encrypts every complete 16-byte block of <paramref name="data"/> in place using ECB.
        /// Trailing bytes beyond the last full block are left unmodified.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveOptimization)]
        public void EncryptEcb(Span<byte> data)
        {
            Vector128<byte>[] keys = RoundKeys;
            Span<Vector128<byte>> blocks = MemoryMarshal.Cast<byte, Vector128<byte>>(data);

            // Makes the JIT remove all the other range checks on keys
            _ = keys[10];

            for (int i = 0; i < blocks.Length; i++)
            {
                Vector128<byte> b = blocks[i];

                b = Sse2.Xor(b, keys[0]);
                b = Aes.Encrypt(b, keys[1]);
                b = Aes.Encrypt(b, keys[2]);
                b = Aes.Encrypt(b, keys[3]);
                b = Aes.Encrypt(b, keys[4]);
                b = Aes.Encrypt(b, keys[5]);
                b = Aes.Encrypt(b, keys[6]);
                b = Aes.Encrypt(b, keys[7]);
                b = Aes.Encrypt(b, keys[8]);
                b = Aes.Encrypt(b, keys[9]);
                b = Aes.EncryptLast(b, keys[10]);

                blocks[i] = b;
            }
        }

        /// <summary>
        /// Decrypts every complete 16-byte block of <paramref name="data"/> in place using ECB.
        /// Trailing bytes beyond the last full block are left unmodified.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveOptimization)]
        public void DecryptEcb(Span<byte> data)
        {
            Vector128<byte>[] keys = RoundKeys;
            Span<Vector128<byte>> blocks = MemoryMarshal.Cast<byte, Vector128<byte>>(data);

            // Makes the JIT remove all the other range checks on keys
            _ = keys[19];

            for (int i = 0; i < blocks.Length; i++)
            {
                Vector128<byte> b = blocks[i];

                // Equivalent inverse cipher: last encryption key first, then the
                // InverseMixColumns'd keys 9..1 (stored at 19..11), then the cipher key.
                b = Sse2.Xor(b, keys[10]);
                b = Aes.Decrypt(b, keys[19]);
                b = Aes.Decrypt(b, keys[18]);
                b = Aes.Decrypt(b, keys[17]);
                b = Aes.Decrypt(b, keys[16]);
                b = Aes.Decrypt(b, keys[15]);
                b = Aes.Decrypt(b, keys[14]);
                b = Aes.Decrypt(b, keys[13]);
                b = Aes.Decrypt(b, keys[12]);
                b = Aes.Decrypt(b, keys[11]);
                b = Aes.DecryptLast(b, keys[0]);

                blocks[i] = b;
            }
        }

        /// <summary>
        /// Encrypts every complete 16-byte block of <paramref name="data"/> in place using CBC,
        /// chaining from <see cref="Iv"/>. Afterwards <see cref="Iv"/> holds the last ciphertext
        /// block (unchanged if no full block was processed). Trailing partial-block bytes are
        /// left unmodified.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveOptimization)]
        public void EncryptCbc(Span<byte> data)
        {
            Vector128<byte>[] keys = RoundKeys;
            Span<Vector128<byte>> blocks = MemoryMarshal.Cast<byte, Vector128<byte>>(data);

            // b carries the chaining value: the IV, then each produced ciphertext block.
            var b = Unsafe.ReadUnaligned<Vector128<byte>>(ref Iv[0]);

            // Makes the JIT remove all the other range checks on keys
            _ = keys[10];

            for (int i = 0; i < blocks.Length; i++)
            {
                b = Sse2.Xor(b, blocks[i]);

                b = Sse2.Xor(b, keys[0]);
                b = Aes.Encrypt(b, keys[1]);
                b = Aes.Encrypt(b, keys[2]);
                b = Aes.Encrypt(b, keys[3]);
                b = Aes.Encrypt(b, keys[4]);
                b = Aes.Encrypt(b, keys[5]);
                b = Aes.Encrypt(b, keys[6]);
                b = Aes.Encrypt(b, keys[7]);
                b = Aes.Encrypt(b, keys[8]);
                b = Aes.Encrypt(b, keys[9]);
                b = Aes.EncryptLast(b, keys[10]);

                blocks[i] = b;
            }

            Unsafe.WriteUnaligned(ref Iv[0], b);
        }

        /// <summary>
        /// Decrypts every complete 16-byte block of <paramref name="data"/> in place using CBC,
        /// chaining from <see cref="Iv"/>. Afterwards <see cref="Iv"/> holds the last ciphertext
        /// block consumed (unchanged if no full block was processed). Trailing partial-block
        /// bytes are left unmodified.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveOptimization)]
        public void DecryptCbc(Span<byte> data)
        {
            Vector128<byte>[] keys = RoundKeys;
            Span<Vector128<byte>> blocks = MemoryMarshal.Cast<byte, Vector128<byte>>(data);

            var iv = Unsafe.ReadUnaligned<Vector128<byte>>(ref Iv[0]);

            // Makes the JIT remove all the other range checks on keys
            _ = keys[19];

            for (int i = 0; i < blocks.Length; i++)
            {
                Vector128<byte> b = blocks[i];

                // The ciphertext must be saved before it is overwritten in place;
                // it is the chaining value for the next block.
                Vector128<byte> nextIv = b;

                b = Sse2.Xor(b, keys[10]);
                b = Aes.Decrypt(b, keys[19]);
                b = Aes.Decrypt(b, keys[18]);
                b = Aes.Decrypt(b, keys[17]);
                b = Aes.Decrypt(b, keys[16]);
                b = Aes.Decrypt(b, keys[15]);
                b = Aes.Decrypt(b, keys[14]);
                b = Aes.Decrypt(b, keys[13]);
                b = Aes.Decrypt(b, keys[12]);
                b = Aes.Decrypt(b, keys[11]);
                b = Aes.DecryptLast(b, keys[0]);

                b = Sse2.Xor(b, iv);
                iv = nextIv;

                blocks[i] = b;
            }

            Unsafe.WriteUnaligned(ref Iv[0], iv);
        }

        /// <summary>
        /// Encrypts (or, equivalently, decrypts) <paramref name="data"/> in place using CTR mode,
        /// starting from the counter block in <see cref="Iv"/>. Any length is accepted.
        /// </summary>
        /// <remarks>
        /// Bytes 8..15 of the counter block are treated as a big-endian 64-bit counter that
        /// wraps without carrying into bytes 0..7. <see cref="Iv"/> is advanced by one for every
        /// block consumed, including a final partial block; the unused keystream bytes of a
        /// partial block are discarded, so the next call starts at a fresh counter value.
        /// </remarks>
        [MethodImpl(MethodImplOptions.AggressiveOptimization)]
        public void EncryptCtr(Span<byte> data)
        {
            Vector128<byte>[] keys = RoundKeys;
            Span<Vector128<byte>> blocks = MemoryMarshal.Cast<byte, Vector128<byte>>(data);

            // Shuffle mask that keeps bytes 0..7 in place and reverses bytes 8..15, turning the
            // big-endian low half of the counter into a little-endian ulong lane (and back).
            Vector128<byte> byteSwapMask = Vector128.Create((ulong)0x706050403020100, 0x8090A0B0C0D0E0F).AsByte();

            // Adds 1 to the upper ulong lane only (counter bytes 8..15).
            Vector128<ulong> inc = Vector128.Create((ulong)0, 1);

            // Makes the JIT remove all the other range checks on keys
            _ = keys[10];

            var iv = Unsafe.ReadUnaligned<Vector128<byte>>(ref Iv[0]);
            Vector128<ulong> bSwappedIv = Ssse3.Shuffle(iv, byteSwapMask).AsUInt64();

            for (int i = 0; i < blocks.Length; i++)
            {
                Vector128<byte> b = Sse2.Xor(iv, keys[0]);
                b = Aes.Encrypt(b, keys[1]);
                b = Aes.Encrypt(b, keys[2]);
                b = Aes.Encrypt(b, keys[3]);
                b = Aes.Encrypt(b, keys[4]);
                b = Aes.Encrypt(b, keys[5]);
                b = Aes.Encrypt(b, keys[6]);
                b = Aes.Encrypt(b, keys[7]);
                b = Aes.Encrypt(b, keys[8]);
                b = Aes.Encrypt(b, keys[9]);
                b = Aes.EncryptLast(b, keys[10]);

                blocks[i] = Sse2.Xor(blocks[i], b);

                // Increase the counter
                bSwappedIv = Sse2.Add(bSwappedIv, inc);
                iv = Ssse3.Shuffle(bSwappedIv.AsByte(), byteSwapMask);
            }

            Unsafe.WriteUnaligned(ref Iv[0], iv);

            if ((data.Length & 0xF) != 0)
            {
                EncryptCtrPartialBlock(data.Slice(blocks.Length * 0x10));
            }
        }

        /// <summary>
        /// XORs a final partial block (fewer than 16 bytes) with the keystream generated from
        /// <see cref="Iv"/>, then advances <see cref="Iv"/> by one.
        /// </summary>
        private void EncryptCtrPartialBlock(Span<byte> data)
        {
            Span<byte> counter = stackalloc byte[0x10];
            Iv.CopyTo(counter);
            EncryptEcb(counter);

            Util.XorArrays(data, counter);

            // Advance the counter exactly like the vectorized loop in EncryptCtr:
            // bytes 8..15 are a big-endian 64-bit counter that wraps without carrying
            // into bytes 0..7. (Incrementing from byte 0 would desynchronize the two paths.)
            for (int i = 0xF; i >= 8; i--)
            {
                if (++Iv[i] != 0) break;
            }
        }

        /// <summary>
        /// Expands a 16-byte AES-128 key into the 20-entry schedule described on
        /// <c>RoundKeys</c>: 11 encryption keys followed by 9 InverseMixColumns'd keys.
        /// </summary>
        /// <param name="key">Key bytes; the caller guarantees at least 16 are available.</param>
        private static Vector128<byte>[] KeyExpansion(Span<byte> key)
        {
            var keys = new Vector128<byte>[20];

            keys[0] = Unsafe.ReadUnaligned<Vector128<byte>>(ref key[0]);

            // Round constants for AES-128 rounds 1..10.
            MakeRoundKey(keys, 1, 0x01);
            MakeRoundKey(keys, 2, 0x02);
            MakeRoundKey(keys, 3, 0x04);
            MakeRoundKey(keys, 4, 0x08);
            MakeRoundKey(keys, 5, 0x10);
            MakeRoundKey(keys, 6, 0x20);
            MakeRoundKey(keys, 7, 0x40);
            MakeRoundKey(keys, 8, 0x80);
            MakeRoundKey(keys, 9, 0x1b);
            MakeRoundKey(keys, 10, 0x36);

            // Decryption keys for the equivalent inverse cipher (rounds 1..9 only;
            // the first and last decryption steps use keys[10] and keys[0] directly).
            for (int i = 1; i < 10; i++)
            {
                keys[10 + i] = Aes.InverseMixColumns(keys[i]);
            }

            return keys;
        }

        /// <summary>
        /// Derives round key <paramref name="i"/> from round key <c>i - 1</c> using AESKEYGENASSIST
        /// with round constant <paramref name="rcon"/>.
        /// </summary>
        private static void MakeRoundKey(Vector128<byte>[] keys, int i, byte rcon)
        {
            Vector128<byte> s = keys[i - 1];
            Vector128<byte> t = keys[i - 1];

            t = Aes.KeygenAssist(t, rcon);

            // Broadcast dword 3 (RotWord(SubWord(w3)) ^ rcon) to all four lanes.
            t = Sse2.Shuffle(t.AsUInt32(), 0xFF).AsByte();

            // Prefix-XOR of the previous key's words: [w0, w0^w1, w0^w1^w2, w0^w1^w2^w3].
            s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 4));
            s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 8));

            keys[i] = Sse2.Xor(s, t);
        }

        /// <summary>
        /// Copies the first 16 bytes of <paramref name="iv"/> into <see cref="Iv"/>.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="iv"/> is shorter than 16 bytes.</exception>
        public void SetIv(Span<byte> iv)
        {
            iv.Slice(0, 0x10).CopyTo(Iv);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment