Skip to content

Instantly share code, notes, and snippets.

@Frooxius
Created August 26, 2021 00:55
Show Gist options
  • Save Frooxius/c3e1a181376a3b97bea53e7efd16f992 to your computer and use it in GitHub Desktop.
Save Frooxius/c3e1a181376a3b97bea53e7efd16f992 to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Mono.Cecil;
using Mono.Cecil.Cil;
using Mono.Cecil.Rocks;
using Mono.Cecil.Pdb;
using System.IO;
using System.ComponentModel;
namespace PostX
{
public static class NumericsVectorsPostprocessor
{
    /// <summary>
    /// Rewrites the assembly at <paramref name="path"/> in place (it is opened with
    /// <c>ReadWrite = true</c> and written back to the same file).
    /// For every top-level type whose full name contains "Vector", each method with an IL body is
    /// marked <c>AggressiveInlining</c> and has <c>InitLocals</c> cleared. Every <c>box</c>
    /// instruction is passed to <see cref="HandleBoxing"/>. Progress and a patched/total summary
    /// are written to the console.
    /// </summary>
    /// <param name="path">File path of the assembly to patch.</param>
    public static void Process(string path)
    {
        // 'using' guarantees the file handle is released even if patching or writing throws.
        using var assembly = AssemblyDefinition.ReadAssembly(path, new ReaderParameters()
        {
            ReadWrite = true
        });
        var module = assembly.MainModule;
        int patched = 0;
        int total = 0;
        foreach (var type in module.Types)
        {
            if (!type.FullName.Contains("Vector"))
                continue;
            Console.WriteLine(type.FullName);
            foreach (var method in type.Methods)
            {
                // Abstract/extern/runtime methods have no IL body (method.Body would be null).
                if (!method.HasBody)
                    continue;
                // Locals appended after this count are scratch locals created by HandleBoxing;
                // only those are reused by later patches in the same method.
                var variables = method.Body.Variables.Count;
                method.Body.InitLocals = false;
                method.AggressiveInlining = true;
                Console.WriteLine("\t" + method.FullName);
                var il = method.Body.GetILProcessor();
                var instructions = method.Body.Instructions;
                for (int i = 0; i < instructions.Count; i++)
                {
                    var op = instructions[i];
                    if (op.OpCode.Code != Code.Box)
                        continue;
                    // HandleBoxing rewrites instructions in place, so capture the original
                    // "box" text before it runs.
                    var description = op.ToString();
                    if (HandleBoxing(method, variables, il, ref i))
                    {
                        Console.WriteLine($"\t\tPATCHED: {description}");
                        patched++;
                    }
                    else
                    {
                        Console.WriteLine($"\t\tUNKNOWN: {description}");
                    }
                    total++;
                }
            }
        }
        Console.WriteLine($"Patched: {patched} out of {total}");
        assembly.Write();
    }

    /// <summary>
    /// Tries to rewrite the <c>box</c> at <paramref name="index"/> together with its neighbours.
    /// Every recognized pattern is "producer; box; unbox.any":
    /// <list type="bullet">
    /// <item><c>ldarg p; box T; unbox.any X</c> becomes <c>ldarga p; conv.u; ldind.X</c>.</item>
    /// <item><c>computation; box T; unbox.any X</c> becomes <c>stloc tmp; ldloca tmp; conv.u; ldind.X</c>.</item>
    /// <item><c>ldelem.any T; box T; unbox.any X</c> becomes <c>ldelema T; conv.u; ldind.X</c>.</item>
    /// <item><c>ldind.*; box X; unbox.any T</c> becomes <c>ldobj T</c>.</item>
    /// <item><c>conv.*/computation; box X; unbox.any T</c> becomes <c>stloc tmp; ldloca tmp; conv.u; ldobj T</c>.</item>
    /// <item><c>box X; unbox.any X</c> (same type reference) is removed entirely.</item>
    /// </list>
    /// Here T is a generic parameter and X is a primitive that has an <c>ldind</c> form.
    /// </summary>
    /// <param name="method">Method that owns the instruction.</param>
    /// <param name="originalVariableCount">Local count before any patching. Only locals added after
    /// it are reused as scratch storage.</param>
    /// <param name="il">IL processor for <paramref name="method"/>'s body.</param>
    /// <param name="index">On entry, the index of the <c>box</c>. On return, the index of the last
    /// instruction consumed, so the caller's <c>i++</c> resumes at the first unexamined one.</param>
    /// <returns><c>true</c> if a pattern was rewritten. <c>false</c> (body unchanged) otherwise.</returns>
    static bool HandleBoxing(MethodDefinition method, int originalVariableCount, ILProcessor il, ref int index)
    {
        var instructions = il.Body.Instructions;
        // Every pattern needs a producer before the box and an unbox.any after it.
        if (index < 1 || index + 1 >= instructions.Count)
            return false;
        var op = instructions[index];
        var prevOp = instructions[index - 1];
        var nextOp = instructions[index + 1];
        if (nextOp.OpCode.Code != Code.Unbox_Any)
            return false;

        if (op.Operand is GenericParameter genericParam)
        {
            if (IsParameter(method, prevOp.OpCode.Code, out int paramIndex))
            {
                var dereferenceOp = TypeToDereference(nextOp.Operand as TypeReference);
                if (dereferenceOp.Code == Code.Nop)
                    return false;
                Rewrite(prevOp, OpCodes.Ldarga_S, method.Parameters[paramIndex]);
                Rewrite(op, OpCodes.Conv_U);
                Rewrite(nextOp, dereferenceOp);
                index++;
                return true;
            }
            if (IsComputation(prevOp.OpCode.Code))
            {
                // Check support before creating a local, so a rejected pattern leaves no unused local behind.
                var dereferenceOp = TypeToDereference(nextOp.Operand as TypeReference);
                if (dereferenceOp.Code == Code.Nop)
                    return false;
                var variable = GetScratchVariable(method, originalVariableCount, genericParam);
                SpillThroughLocal(il, op, nextOp, variable, Instruction.Create(dereferenceOp));
                index += 3;
                return true;
            }
            if (prevOp.OpCode.Code == Code.Ldelem_Any)
            {
                var dereferenceOp = TypeToDereference(nextOp.Operand as TypeReference);
                if (dereferenceOp.Code == Code.Nop)
                    return false;
                Rewrite(prevOp, OpCodes.Ldelema, genericParam);
                Rewrite(op, OpCodes.Conv_U);
                Rewrite(nextOp, dereferenceOp);
                index++;
                return true;
            }
        }

        if (nextOp.Operand is GenericParameter nextGenericParam)
        {
            if (IsDereference(prevOp.OpCode.Code))
            {
                // The address the ldind consumed is read directly as T.
                Rewrite(op, OpCodes.Ldobj, nextGenericParam);
                RemoveInstruction(il, nextOp);
                RemoveInstruction(il, prevOp);
                // The rewritten instruction shifted down to index - 1.
                index--;
                return true;
            }
            if (IsConvert(prevOp.OpCode.Code) || IsComputation(prevOp.OpCode.Code))
            {
                // Spill the value into a scratch local of the boxed type and reread its bytes as T.
                var variable = GetScratchVariable(method, originalVariableCount, (TypeReference)op.Operand);
                SpillThroughLocal(il, op, nextOp, variable, Instruction.Create(OpCodes.Ldobj, nextGenericParam));
                index += 3;
                return true;
            }
            // No specific pattern matched: fall through so the identity case below can still
            // apply, and report UNKNOWN instead of claiming a patch.
        }

        if ((op.Operand as TypeReference) == (nextOp.Operand as TypeReference))
        {
            // box X; unbox.any X is an identity conversion. Drop both instructions.
            RemoveInstruction(il, nextOp);
            RemoveInstruction(il, op);
            // The instruction that followed the pair now sits at the old index.
            index--;
            return true;
        }
        return false;
    }

    /// <summary>
    /// Changes an instruction's opcode and operand in place. Keeping the same object keeps branch,
    /// switch and exception-handler references to it valid.
    /// </summary>
    static void Rewrite(Instruction instruction, OpCode opCode, object operand = null)
    {
        instruction.OpCode = opCode;
        instruction.Operand = operand;
    }

    /// <summary>
    /// Removes <paramref name="target"/> after redirecting every branch, switch and
    /// exception-handler reference to the instruction that follows it. Callers only remove
    /// instructions where continuing at that successor is equivalent.
    /// </summary>
    static void RemoveInstruction(ILProcessor il, Instruction target)
    {
        var body = il.Body;
        var successor = target.Next;
        foreach (var instruction in body.Instructions)
        {
            if (ReferenceEquals(instruction.Operand, target))
            {
                instruction.Operand = successor;
            }
            else if (instruction.Operand is Instruction[] targets)
            {
                for (int i = 0; i < targets.Length; i++)
                {
                    if (ReferenceEquals(targets[i], target))
                        targets[i] = successor;
                }
            }
        }
        foreach (var handler in body.ExceptionHandlers)
        {
            if (handler.TryStart == target) handler.TryStart = successor;
            if (handler.TryEnd == target) handler.TryEnd = successor;
            if (handler.HandlerStart == target) handler.HandlerStart = successor;
            if (handler.HandlerEnd == target) handler.HandlerEnd = successor;
            if (handler.FilterStart == target) handler.FilterStart = successor;
        }
        il.Remove(target);
    }

    /// <summary>
    /// Returns a local added by this postprocessor (index at or after
    /// <paramref name="originalVariableCount"/>) whose type is the same
    /// <see cref="TypeReference"/> object as <paramref name="variableType"/>. If none exists, a new
    /// local of that type is appended to the method.
    /// </summary>
    static VariableDefinition GetScratchVariable(MethodDefinition method, int originalVariableCount, TypeReference variableType)
    {
        var variables = method.Body.Variables;
        var variable = variables.Skip(originalVariableCount).FirstOrDefault(v => v.VariableType == variableType);
        if (variable == null)
        {
            variable = new VariableDefinition(variableType);
            variables.Add(variable);
        }
        return variable;
    }

    /// <summary>
    /// Turns <paramref name="box"/> into a store to <paramref name="variable"/> and
    /// <paramref name="unbox"/> into a load of its address, then appends <c>conv.u</c> and
    /// <paramref name="load"/> after it.
    /// </summary>
    static void SpillThroughLocal(ILProcessor il, Instruction box, Instruction unbox, VariableDefinition variable, Instruction load)
    {
        // The short forms encode the local index in one byte.
        bool shortForm = variable.Index <= byte.MaxValue;
        Rewrite(box, shortForm ? OpCodes.Stloc_S : OpCodes.Stloc, variable);
        Rewrite(unbox, shortForm ? OpCodes.Ldloca_S : OpCodes.Ldloca, variable);
        var convert = Instruction.Create(OpCodes.Conv_U);
        il.InsertAfter(unbox, convert);
        il.InsertAfter(convert, load);
    }

    /// <summary>True for the opcodes treated as "a computed value is on the stack": add, mul, call.</summary>
    static bool IsComputation(Code code)
    {
        switch (code)
        {
            case Code.Add:
            case Code.Mul:
            case Code.Call:
                return true;
            default:
                return false;
        }
    }

    /// <summary>
    /// Maps <c>ldarg.0</c>-<c>ldarg.3</c> to an index into <c>method.Parameters</c>. In instance
    /// methods argument 0 is <c>this</c>, which is rejected and shifts the other indices by one.
    /// </summary>
    /// <returns><c>false</c> (with <paramref name="index"/> = -1) for any other opcode or for <c>this</c>.</returns>
    static bool IsParameter(MethodDefinition method, Code code, out int index)
    {
        var isStatic = method.IsStatic;
        switch (code)
        {
            case Code.Ldarg_0:
                if (!isStatic)
                {
                    index = -1;
                    return false;
                }
                index = 0;
                return true;
            case Code.Ldarg_1:
                index = isStatic ? 1 : 0;
                return true;
            case Code.Ldarg_2:
                index = isStatic ? 2 : 1;
                return true;
            case Code.Ldarg_3:
                index = isStatic ? 3 : 2;
                return true;
            default:
                index = -1;
                return false;
        }
    }

    /// <summary>True for the non-overflow-checked numeric conversion opcodes.</summary>
    static bool IsConvert(Code code)
    {
        switch (code)
        {
            case Code.Conv_I:
            case Code.Conv_I1:
            case Code.Conv_I2:
            case Code.Conv_I4:
            case Code.Conv_I8:
            case Code.Conv_R4:
            case Code.Conv_R8:
            case Code.Conv_U:
            case Code.Conv_U1:
            case Code.Conv_U2:
            case Code.Conv_U4:
            case Code.Conv_U8:
                return true;
            default:
                return false;
        }
    }

    /// <summary>True for the typed indirect-load opcodes (<c>ldind.*</c>).</summary>
    static bool IsDereference(Code code)
    {
        switch (code)
        {
            case Code.Ldind_I:
            case Code.Ldind_I1:
            case Code.Ldind_I2:
            case Code.Ldind_I4:
            case Code.Ldind_I8:
            case Code.Ldind_R4:
            case Code.Ldind_R8:
            case Code.Ldind_Ref:
            case Code.Ldind_U1:
            case Code.Ldind_U2:
            case Code.Ldind_U4:
                return true;
            default:
                return false;
        }
    }

    /// <summary>
    /// Returns the <c>ldind.*</c> opcode that loads a value of the given <c>System</c> primitive
    /// type, or <see cref="OpCodes.Nop"/> as an "unsupported" sentinel. The sentinel is returned
    /// for null, generic parameters and non-System types that merely share a primitive's name.
    /// </summary>
    static OpCode TypeToDereference(TypeReference type)
    {
        if (type == null || type.Namespace != "System")
            return OpCodes.Nop;
        return type.Name switch
        {
            "Byte" => OpCodes.Ldind_U1,
            "UInt16" => OpCodes.Ldind_U2,
            "UInt32" => OpCodes.Ldind_U4,
            // IL has no ldind.u8; ldind.i8 loads the same eight bytes.
            "UInt64" => OpCodes.Ldind_I8,
            "SByte" => OpCodes.Ldind_I1,
            "Int16" => OpCodes.Ldind_I2,
            "Int32" => OpCodes.Ldind_I4,
            "Int64" => OpCodes.Ldind_I8,
            "Single" => OpCodes.Ldind_R4,
            "Double" => OpCodes.Ldind_R8,
            "IntPtr" => OpCodes.Ldind_I,
            "UIntPtr" => OpCodes.Ldind_I,
            _ => OpCodes.Nop,
        };
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment