Skip to content

Instantly share code, notes, and snippets.

@mjsabby
Last active April 29, 2024 05:42
Show Gist options
  • Save mjsabby/b87b629fa902cd641b955892b1ab9604 to your computer and use it in GitHub Desktop.
Save mjsabby/b87b629fa902cd641b955892b1ab9604 to your computer and use it in GitHub Desktop.
Verify RS256 using more primitive operations to find where time is spent
// <PackageReference Include="SymCryptNative" Version="103.4.2" />
// <PackageReference Include="BenchmarkDotNet" Version="0.13.12" />
namespace VerifyRS256JWTSignature
{
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Running;
using Microsoft.IdentityModel.Tokens;
using System;
using System.IdentityModel.Tokens.Jwt;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using System.Text;
using unsafe SymCryptRsaCoreEncFunctionPointer = delegate* unmanaged[Stdcall, SuppressGCTransition]<nint, byte*, nint, SYMCRYPT_NUMBER_FORMAT, int, byte*, nint, byte*, nint, int>;
/// <summary>
/// Compares RS256 JWT signature verification done with raw SymCrypt primitives
/// (<see cref="BenchmarkVerifySHA256Signature"/>) against the full
/// <see cref="JwtSecurityTokenHandler"/> validation pipeline
/// (<see cref="BenchmarkValidateJWTUsingIdentityModel"/>).
/// </summary>
public class Benchmarks : IDisposable
{
    private readonly JwtSecurityTokenHandler jwtSecurityTokenHandler;
    private readonly string jwt;
    private readonly TokenValidationParameters validationParameters;
    private readonly SymCryptKeyHandle publicKey;
    private readonly unsafe SymCryptRsaCoreEncFunctionPointer encrypt;

    /// <summary>
    /// Resolves SymCryptRsaCoreEnc, imports the hard-coded public key into both SymCrypt and
    /// IdentityModel, and verifies the sample token once so that a broken setup fails fast.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// SymCrypt or SymCryptRsaCoreEnc cannot be resolved, the native key cannot be created,
    /// or the sample token's signature does not verify.
    /// </exception>
    public Benchmarks()
    {
        nint dll = Program.GetDllPointer(NativeMethods.SymCrypt);
        if (dll == 0)
        {
            throw new InvalidOperationException("Failed to load the SymCrypt library");
        }

        nint encryptAddress = Program.GetDllFunction(dll, "SymCryptRsaCoreEnc");
        if (encryptAddress == 0)
        {
            // Calling through a null function pointer would crash the process instead of throwing.
            throw new InvalidOperationException("Failed to resolve SymCryptRsaCoreEnc");
        }

        unsafe
        {
            encrypt = (SymCryptRsaCoreEncFunctionPointer)encryptAddress;
        }

        // Only the exported public parameters are needed below, so release the RSA instance right away.
        RSAParameters rsaParameters;
        using (RSA rsa = RSA.Create())
        {
            var pem = @"MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAu1SU1LfVLPHCozMxH2Mo4lgOEePzNm0tRgeLezV6ffAt0gunVTLw7onLRnrq0/IzW7yWR7QkrmBL7jTKEn5u+qKhbwKfBstIs+bMY2Zkp18gnTxKLxoS2tFczGkPLPgizskuemMghRniWaoLcyehkd3qqGElvW/VDL5AaWTg0nLVkjRo9z+40RQzuVaE8AkAFmxZzow3x+VJYKdjykkJ0iT9wCS0DRTXu269V264Vf/3jvredZiKRkgwlL9xNAwxXFg0x/XFw005UWVRIkdgcKWTjpBP2dPwVZ4WWC+9aGVd+Gyn1o0CLelf4rEjGoXbAAEgAqeGUxrcIlbjXfbcmwIDAQAB";
            rsa.ImportSubjectPublicKeyInfo(Convert.FromBase64String(pem), out _);
            rsaParameters = rsa.ExportParameters(false);
        }

        byte[] modulus = rsaParameters.Modulus;
        byte[] publicExponent = EncodePublicExponentForSymCrypt(rsaParameters.Exponent);

        SYMCRYPT_RSA_PARAMS rsaParams = new()
        {
            Version = 1,
            BitsOfModulus = modulus.Length * 8,
            NumOfPrimes = 0,
            NumOfPublicExponents = 1
        };
        const int SYMCRYPT_FLAG_RSAKEY_ENCRYPT = 0x2000;

        nint key = NativeMethods.SymCryptRsakeyAllocate(in rsaParams, flags: 0);
        if (key == 0)
        {
            throw new InvalidOperationException("Failed to allocate RSA key");
        }

        var err = NativeMethods.SymCryptRsakeySetValue(modulus, modulus.Length, publicExponent, numberOfExponents: 1, 0, 0, 0, SYMCRYPT_NUMBER_FORMAT.MSB_FIRST, SYMCRYPT_FLAG_RSAKEY_ENCRYPT, key);
        if (err != 0)
        {
            // The key has not been handed to a SymCryptKeyHandle yet, so free it here to avoid a leak.
            _ = NativeMethods.SymCryptRsakeyFree(key);
            throw new InvalidOperationException("Failed to set RSA key value");
        }

        publicKey = new SymCryptKeyHandle(rsaParams.BitsOfModulus, key);
        jwt = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ";

        // Sanity check: the hand-rolled verifier must accept the sample token before anything is benchmarked.
        if (!BenchmarkVerifySHA256Signature())
        {
            publicKey.Dispose();
            throw new InvalidOperationException("Invalid signature");
        }

        validationParameters = new TokenValidationParameters
        {
            ValidateLifetime = false,
            ValidateActor = false,
            ValidateIssuer = false,
            ValidateAudience = false,
            ValidateTokenReplay = false,
            ValidateSignatureLast = false,
            ValidateWithLKG = false,
            ValidateIssuerSigningKey = true,
            IssuerSigningKey = new RsaSecurityKey(rsaParameters)
        };
        jwtSecurityTokenHandler = new JwtSecurityTokenHandler();
    }

    /// <summary>
    /// Splits the JWT, decodes the signature and verifies it with the raw SymCrypt path.
    /// </summary>
    /// <returns><see langword="true"/> if the signature verifies.</returns>
    [Benchmark]
    public unsafe bool BenchmarkVerifySHA256Signature()
    {
        var jwtParts = jwt.Split('.');
        var dataToVerify2 = Encoding.UTF8.GetBytes($"{jwtParts[0]}.{jwtParts[1]}");
        var signature2 = Base64UrlDecode(jwtParts[2]);
        return Program.VerifySHA256Signature(encrypt, dataToVerify2, signature2, publicKey);
    }

    /// <summary>
    /// Validates the same JWT through <see cref="JwtSecurityTokenHandler"/> with only signing-key
    /// checks enabled.
    /// </summary>
    [Benchmark]
    public void BenchmarkValidateJWTUsingIdentityModel()
    {
        _ = jwtSecurityTokenHandler.ValidateToken(jwt, validationParameters, out _);
    }

    /// <summary>
    /// Releases the native SymCrypt key.
    /// </summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Standard dispose hook. <paramref name="disposing"/> is <see langword="true"/> on the
    /// explicit <see cref="Dispose()"/> path.
    /// </summary>
    protected virtual void Dispose(bool disposing)
    {
        // publicKey is a managed object with its own finalizer; only touch it on the explicit path.
        if (disposing)
        {
            publicKey.Dispose();
        }
    }

    // Decodes base64url (RFC 4648 section 5) by mapping back to the standard alphabet and restoring padding.
    private static byte[] Base64UrlDecode(string base64UrlEncodedString)
    {
        string output = base64UrlEncodedString;
        output = output.Replace('-', '+').Replace('_', '/');
        switch (output.Length % 4)
        {
            case 2:
                output += "==";
                break;
            case 3:
                output += "=";
                break;
            default:
                // Remainder 0 needs no padding. Remainder 1 is never valid base64, and
                // Convert.FromBase64String rejects it with a FormatException.
                break;
        }
        return Convert.FromBase64String(output);
    }

    // Converts RSAParameters.Exponent (a big-endian byte string such as { 0x01, 0x00, 0x01 }) into
    // the 8-byte, native-endian UINT64 that SymCryptRsakeySetValue reads for each public exponent.
    // Passing the raw big-endian bytes makes SymCrypt read 8 bytes from a shorter array.
    private static byte[] EncodePublicExponentForSymCrypt(byte[] bigEndianExponent)
    {
        if (bigEndianExponent.Length == 0 || bigEndianExponent.Length > sizeof(ulong))
        {
            throw new InvalidOperationException("RSA public exponent must be between 1 and 8 bytes long");
        }

        ulong value = 0;
        foreach (byte b in bigEndianExponent)
        {
            value = (value << 8) | b;
        }

        // BitConverter uses the machine's byte order, which is how a C UINT64 sits in memory.
        return BitConverter.GetBytes(value);
    }
}
/// <summary>
/// Source-generated P/Invoke declarations for the kernel32 loader functions and the SymCrypt
/// RSA entry points used by this benchmark.
/// </summary>
internal static partial class NativeMethods
{
// Native library holding LoadLibraryA / GetProcAddress / GetLastError / FreeLibrary.
public const string Kernel32 = "kernel32.dll";
// SymCrypt library name. It is used by the [LibraryImport] declarations below and is also passed
// to Program.GetDllPointer by the Benchmarks constructor to resolve SymCryptRsaCoreEnc manually.
public const string SymCrypt = "symcrypt";
// dllToLoad is a pointer to a NUL-terminated ANSI string; Program.GetDllPointer passes a
// Marshal.StringToHGlobalAnsi buffer. NOTE(review): presumably returns 0 on failure (Win32
// convention) — confirm; Program.GetDllPointer does not check the result itself.
[LibraryImport(Kernel32)]
public static partial nint LoadLibraryA(nint dllToLoad);
// procedureName is a pointer to a NUL-terminated ANSI export name (see Program.GetDllFunction).
[LibraryImport(Kernel32)]
public static partial nint GetProcAddress(nint hModule, nint procedureName);
// Not called anywhere in this file. NOTE(review): reading the Win32 last error through a separate
// P/Invoke is unreliable because the runtime may overwrite it between calls; the usual .NET
// pattern is SetLastError = true on the failing import plus Marshal.GetLastPInvokeError().
[LibraryImport(Kernel32)]
public static partial int GetLastError();
// Not called anywhere in this file; the module handle from LoadLibraryA is never released.
[LibraryImport(Kernel32)]
[return: MarshalAs(UnmanagedType.Bool)]
public static partial bool FreeLibrary(nint hModule);
// Frees a key returned by SymCryptRsakeyAllocate (called from SymCryptKeyHandle).
// NOTE(review): upstream symcrypt.h appears to declare this as returning VOID; if so the int
// result is meaningless, and the only caller here discards it anyway — confirm against the header.
[LibraryImport(SymCrypt)]
public static partial int SymCryptRsakeyFree(nint hKey);
// Allocates an empty RSA key object sized by rsaParams. NOTE(review): presumably returns 0 on
// allocation failure — confirm.
[LibraryImport(SymCrypt)]
public static partial nint SymCryptRsakeyAllocate(in SYMCRYPT_RSA_PARAMS rsaParams, int flags);
// Returns the scratch-buffer size Program.VerifySHA256Signature allocates for SymCryptRsaCoreEnc.
// NOTE(review): the native return type is presumably SIZE_T; declaring it int assumes the value fits in 32 bits.
[LibraryImport(SymCrypt)]
public static partial int SymCryptRsaCoreEncScratchSpace(nint ptr);
// Loads modulus/exponent (and optionally primes) into a key from SymCryptRsakeyAllocate; returns 0 on success
// (the Benchmarks constructor treats non-zero as failure). `_` / `__` are presumably the prime pointer
// and prime-size arrays, passed as 0 for a public-only key — confirm.
// NOTE(review): symcrypt.h appears to declare the exponent argument as PCUINT64 (numberOfExponents
// native-endian 64-bit values), not a big-endian byte string, so callers should pass 8 bytes per
// exponent — confirm against the header.
[LibraryImport(SymCrypt)]
public static unsafe partial int SymCryptRsakeySetValue(ReadOnlySpan<byte> modulus, nint modulusSize, ReadOnlySpan<byte> exponent, int numberOfExponents, nint _, nint __, int numPrimes, SYMCRYPT_NUMBER_FORMAT format, int flags, nint rsaKey);
}
/// <summary>
/// Entry point plus the hand-rolled RS256 verifier and the native-library helpers it relies on.
/// </summary>
internal static partial class Program
{
    // DER encoding of the SHA-256 DigestInfo header (RFC 3447 / RFC 8017, section 9.2, note 1).
    // In a valid encoded message this prefix is immediately followed by the 32-byte SHA-256 hash.
    // Returning a constant byte[] as ReadOnlySpan<byte> lets the compiler serve it from static data.
    private static ReadOnlySpan<byte> Sha256DigestInfoPrefix => new byte[]
    {
        0x30, 0x31, 0x30, 0x0D, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20,
    };

    /// <summary>
    /// Constructs <see cref="Benchmarks"/> once as a sanity check (its constructor throws if the
    /// sample token does not verify), then runs the benchmarks.
    /// </summary>
    public static void Main()
    {
        // Dispose the sanity-check instance so its native key is freed now rather than by the finalizer.
        new Benchmarks().Dispose();
        BenchmarkRunner.Run<Benchmarks>();
    }

    /// <summary>
    /// Verifies an RSASSA-PKCS1-v1_5 / SHA-256 (RS256) signature. SymCrypt performs the raw RSA
    /// public-key operation, and the EMSA-PKCS1-v1_5 encoding is checked by hand.
    /// </summary>
    /// <param name="encrypt">Function pointer to SymCrypt's SymCryptRsaCoreEnc.</param>
    /// <param name="dataToVerify">The signed bytes (for a JWT, the UTF-8 of "header.payload").</param>
    /// <param name="signature">Big-endian signature; must be exactly the key's byte length.</param>
    /// <param name="publicKey">SymCrypt RSA key of at least 2048 bits.</param>
    /// <returns><see langword="true"/> if the signature is valid; otherwise <see langword="false"/>.</returns>
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    public static unsafe bool VerifySHA256Signature(SymCryptRsaCoreEncFunctionPointer encrypt, ReadOnlySpan<byte> dataToVerify, ReadOnlySpan<byte> signature, SymCryptKeyHandle publicKey)
    {
        var bitLength = publicKey.BitLength;
        // RS256 requires 2048 bits to be secure
        if (bitLength < 2048)
        {
            return false;
        }

        int roundedKeySizeInBytes = (bitLength + 7) >> 3;
        // Only accept key sizes that are a multiple of 16 bytes, i.e. 2048, 3072, 4096 bits, etc.
        if ((roundedKeySizeInBytes & 0xF) != 0)
        {
            return false;
        }

        if (signature.Length != roundedKeySizeInBytes)
        {
            return false;
        }

        Span<byte> sha256Hash = stackalloc byte[SHA256.HashSizeInBytes];
        _ = SHA256.HashData(dataToVerify, sha256Hash);

        int scratchSize = NativeMethods.SymCryptRsaCoreEncScratchSpace(publicKey.Pointer);
        // Buffers handed to SymCryptRsaCoreEnc are aligned to this many bytes
        // (presumably mirroring SymCrypt's SYMCRYPT_ASYM_ALIGN_VALUE).
        const int SYMCRYPT_ASYM_ALIGN_VALUE = 32;
        const int extra = SYMCRYPT_ASYM_ALIGN_VALUE - 1;
        Span<byte> destBuffer = stackalloc byte[roundedKeySizeInBytes + extra];
        Span<byte> scratchBuffer = stackalloc byte[scratchSize + extra];

        int destOffset;
        fixed (byte* destPtr = destBuffer)
        fixed (byte* scratchPtr = scratchBuffer)
        fixed (byte* signaturePtr = signature)
        {
            destOffset = AlignPointerForSymCrypt(destPtr, SYMCRYPT_ASYM_ALIGN_VALUE);
            int scratchOffset = AlignPointerForSymCrypt(scratchPtr, SYMCRYPT_ASYM_ALIGN_VALUE);
            int result = encrypt(publicKey.Pointer, signaturePtr, signature.Length, SYMCRYPT_NUMBER_FORMAT.MSB_FIRST, 0, destPtr + destOffset, roundedKeySizeInBytes, scratchPtr + scratchOffset, scratchSize);

            // Only the raw Pointer value reaches native code. Keep the handle reachable until the call
            // returns so its finalizer cannot free the key while SymCrypt is still using it.
            GC.KeepAlive(publicKey);

            if (result != 0)
            {
                return false;
            }
        }

        return IsPkcs1V15Sha256EncodedMessage(destBuffer.Slice(destOffset, roundedKeySizeInBytes), sha256Hash);
    }

    // Checks that encodedMessage is exactly EMSA-PKCS1-v1_5(SHA-256) for the given hash and the
    // message's own length: 0x00 || 0x01 || PS (all 0xFF, at least 8 bytes) || 0x00 || DigestInfo || H.
    // The padding length is derived from the message length, so 2048-, 3072- and 4096-bit keys all work.
    private static bool IsPkcs1V15Sha256EncodedMessage(ReadOnlySpan<byte> encodedMessage, ReadOnlySpan<byte> sha256Hash)
    {
        ReadOnlySpan<byte> digestInfoPrefix = Sha256DigestInfoPrefix;
        int digestInfoLength = digestInfoPrefix.Length + sha256Hash.Length;

        // Index of the 0x00 byte separating PS from the DigestInfo; PS spans [2, separatorIndex).
        int separatorIndex = encodedMessage.Length - digestInfoLength - 1;
        if (separatorIndex < 2 + 8)
        {
            return false;
        }

        return encodedMessage[0] == 0x00
            && encodedMessage[1] == 0x01
            && encodedMessage[2..separatorIndex].IndexOfAnyExcept((byte)0xFF) < 0
            && encodedMessage[separatorIndex] == 0x00
            && encodedMessage.Slice(separatorIndex + 1, digestInfoPrefix.Length).SequenceEqual(digestInfoPrefix)
            && encodedMessage[^sha256Hash.Length..].SequenceEqual(sha256Hash);
    }

    /// <summary>
    /// Loads a native library by name through LoadLibraryA and returns its module handle.
    /// The temporary ANSI copy of <paramref name="path"/> is always freed.
    /// </summary>
    public static nint GetDllPointer(string path)
    {
        nint dllStringPtr = 0;
        try
        {
            dllStringPtr = Marshal.StringToHGlobalAnsi(path);
            return NativeMethods.LoadLibraryA(dllStringPtr);
        }
        finally
        {
            if (dllStringPtr != 0)
            {
                Marshal.FreeHGlobal(dllStringPtr);
            }
        }
    }

    /// <summary>
    /// Looks up an exported function in <paramref name="dll"/> through GetProcAddress.
    /// The temporary ANSI copy of <paramref name="functionName"/> is always freed.
    /// </summary>
    public static nint GetDllFunction(nint dll, string functionName)
    {
        nint functionNamePtr = 0;
        try
        {
            functionNamePtr = Marshal.StringToHGlobalAnsi(functionName);
            return NativeMethods.GetProcAddress(dll, functionNamePtr);
        }
        finally
        {
            if (functionNamePtr != 0)
            {
                Marshal.FreeHGlobal(functionNamePtr);
            }
        }
    }

    // Decodes base64url by mapping back to the standard alphabet and restoring padding.
    // Not called within Program; Benchmarks carries its own copy.
    private static byte[] Base64UrlDecode(string base64UrlEncodedString)
    {
        string output = base64UrlEncodedString;
        output = output.Replace('-', '+').Replace('_', '/');
        switch (output.Length % 4)
        {
            case 2:
                output += "==";
                break;
            case 3:
                output += "=";
                break;
            default:
                // Remainder 0 needs no padding. Remainder 1 is never valid base64, and
                // Convert.FromBase64String rejects it with a FormatException.
                break;
        }
        return Convert.FromBase64String(output);
    }

    // Returns how many bytes must be added to ptr to reach the next multiple of alignment (0 if already aligned).
    private static unsafe int AlignPointerForSymCrypt(byte* ptr, int alignment)
    {
        long baseAddress = (long)ptr;
        long offset = (alignment - (baseAddress % alignment)) % alignment;
        return (int)offset;
    }
}
/// <summary>
/// Byte order of big-integer buffers exchanged with SymCrypt. The numeric values are passed
/// straight through to native code and must not change.
/// </summary>
internal enum SYMCRYPT_NUMBER_FORMAT : int
{
    /// <summary>Least-significant byte first.</summary>
    LSB_FIRST = 1,

    /// <summary>Most-significant byte first; used here for the RSA modulus and the signature.</summary>
    MSB_FIRST = 2,
}
// Managed mirror of SymCrypt's RSA key parameters, passed by reference to
// NativeMethods.SymCryptRsakeyAllocate. Sequential layout means the member order below IS the
// native field order; do not reorder, retype or add members.
// NOTE(review): assumes the native struct is four 32-bit fields in this order — confirm against symcrypt.h.
[StructLayout(LayoutKind.Sequential)]
internal struct SYMCRYPT_RSA_PARAMS
{
// Parameter-structure version; the Benchmarks constructor sets 1.
public int Version { get; set; }
// Modulus size in bits (the Benchmarks constructor uses modulus byte length * 8).
public int BitsOfModulus { get; set; }
// Number of private primes; 0 for the public-only key built by the Benchmarks constructor.
public int NumOfPrimes { get; set; }
// Number of public exponents; the Benchmarks constructor uses 1.
public int NumOfPublicExponents { get; set; }
}
/// <summary>
/// Owns a native SymCrypt RSA key pointer and frees it with SymCryptRsakeyFree, either on
/// <see cref="Dispose"/> or, as a fallback, from the finalizer.
/// </summary>
/// <param name="bitLength">Modulus size of the key in bits.</param>
/// <param name="handle">Native key pointer obtained from SymCryptRsakeyAllocate.</param>
internal sealed class SymCryptKeyHandle(int bitLength, nint handle) : IDisposable
{
    // Set once the key has been released, so repeated Dispose calls are no-ops.
    private bool _disposed;

    /// <summary>Native key pointer; reset to 0 once the key has been freed.</summary>
    public nint Pointer { get; private set; } = handle;

    /// <summary>Modulus size in bits, as supplied at construction.</summary>
    public int BitLength { get; } = bitLength;

    ~SymCryptKeyHandle() => Release();

    /// <summary>Frees the native key now and suppresses finalization.</summary>
    public void Dispose()
    {
        Release();
        GC.SuppressFinalize(this);
    }

    // Shared by Dispose and the finalizer. The class holds only an unmanaged pointer and no managed
    // disposables, so both paths do exactly the same work.
    private void Release()
    {
        if (_disposed)
        {
            return;
        }

        nint key = Pointer;
        if (key != 0)
        {
            _ = NativeMethods.SymCryptRsakeyFree(key);
            Pointer = 0;
        }

        _disposed = true;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment