Skip to content

Instantly share code, notes, and snippets.

@lbmaian
Created December 16, 2019 06:21
Show Gist options
  • Save lbmaian/1b425c1c9d7ec8e5869162d463174742 to your computer and use it in GitHub Desktop.
Save lbmaian/1b425c1c9d7ec8e5869162d463174742 to your computer and use it in GitHub Desktop.
Proof of concept of static constructor patching in RimWorld
using System;
using System.Collections.Generic;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.CompilerServices;
using Harmony;
using Harmony.ILCopying;
using Verse;
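// Expected log output, in order:
// TestMod
// Bar
// Baz
// Notably does NOT output: Foo
// ("Baz" comes from the transpiled copy of TestStaticClass1's static constructor; the original, which logs "Foo", is skipped.)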
namespace TestMod
{
/// <summary>
/// Proof-of-concept mod that replaces the [StaticConstructorOnStartup] static constructor of
/// <see cref="TestStaticClass1"/> with a transpiled copy, without modifying the original type.
/// </summary>
/// <remarks>
/// The copy is emitted into a dynamic assembly that is appended to this mod's loaded assemblies, and
/// <see cref="StaticConstructorOnStartupUtility.CallAll"/> is patched so that the original type's static
/// constructor is not run.
/// </remarks>
class TestMod : Mod
{
    // Value HarmonyInstance.DEBUG is set to while this mod's constructor runs.
    const bool DEBUG = false;

    /// <summary>
    /// Builds and registers the patched copy of <see cref="TestStaticClass1"/>, then patches
    /// <see cref="StaticConstructorOnStartupUtility.CallAll"/> to skip the original type.
    /// </summary>
    /// <param name="content">This mod's content pack; the dynamic assembly is added to its loaded assemblies.</param>
    public TestMod(ModContentPack content) : base(content)
    {
        Log.Message("TestMod");
        // Restore the previous value afterwards instead of forcing it to false.
        var previousDebug = HarmonyInstance.DEBUG;
        HarmonyInstance.DEBUG = DEBUG;
        try
        {
            // Build and register the replacement first: if emitting it throws, CallAll stays unpatched rather than
            // skipping TestStaticClass1's static constructor with nothing registered to replace it.
            content.assemblies.loadedAssemblies.Add(DefinePatchedStaticConstructorAssembly(typeof(TestStaticClass1)));

            var harmony = HarmonyInstance.Create("TestMod");
            harmony.Patch(typeof(StaticConstructorOnStartupUtility).GetMethod(nameof(StaticConstructorOnStartupUtility.CallAll)),
                transpiler: new HarmonyMethod(typeof(TestMod), nameof(StaticConstructorOnStartupUtilityTranspiler)));
        }
        finally
        {
            HarmonyInstance.DEBUG = previousDebug;
        }
    }

    /// <summary>
    /// Emits a dynamic assembly containing a type with the same full name and attributes as
    /// <paramref name="targetType"/>, marked [StaticConstructorOnStartup], whose static constructor is a copy of
    /// the original run through <see cref="StaticConstructorTranspiler"/>.
    /// </summary>
    /// <param name="targetType">Type whose static constructor is copied.</param>
    /// <returns>The dynamic assembly containing the created type.</returns>
    /// <exception cref="ArgumentException"><paramref name="targetType"/> has no static constructor.</exception>
    static AssemblyBuilder DefinePatchedStaticConstructorAssembly(Type targetType)
    {
        var staticConstructor = targetType.TypeInitializer;
        if (staticConstructor == null)
            throw new ArgumentException($"{targetType} has no static constructor to patch", nameof(targetType));

        var assemblyName = new AssemblyName(typeof(TestMod).Assembly.GetName().Name + "Patched");
        var assemblyBuilder = AppDomain.CurrentDomain.DefineDynamicAssembly(assemblyName, AssemblyBuilderAccess.Run);
        var moduleBuilder = assemblyBuilder.DefineDynamicModule(assemblyName.Name);

        // Same full name and attributes as the original. The attribute is what presumably lets CallAll's scan of
        // [StaticConstructorOnStartup] types pick the copy up — NOTE(review): confirm against GenTypes.
        var typeBuilder = moduleBuilder.DefineType(targetType.FullName, targetType.Attributes);
        var attributeBuilder = new CustomAttributeBuilder(typeof(StaticConstructorOnStartup).GetConstructor(Type.EmptyTypes), new object[0]);
        typeBuilder.SetCustomAttribute(attributeBuilder);

        var constructorBuilder = typeBuilder.DefineConstructor(staticConstructor.Attributes, staticConstructor.CallingConvention, Type.EmptyTypes);
        EmitTranspiledCopy(staticConstructor, constructorBuilder.GetILGenerator());
        typeBuilder.CreateType();
        return assemblyBuilder;
    }

    /// <summary>
    /// Copies the IL of <paramref name="original"/> into <paramref name="generator"/>, applying
    /// <see cref="StaticConstructorTranspiler"/>, then marks the trailing labels/exception blocks the copier
    /// reports and terminates the body with <c>ret</c>.
    /// </summary>
    static void EmitTranspiledCopy(ConstructorInfo original, ILGenerator generator)
    {
        var methodCopier = new MethodCopier(original, generator);
        methodCopier.AddTranspiler(typeof(TestMod).GetMethod(nameof(StaticConstructorTranspiler), AccessTools.all));
        var endLabels = new List<Label>();
        var endBlocks = new List<ExceptionBlock>();
        methodCopier.Finalize(endLabels, endBlocks);
        foreach (var label in endLabels)
            Emitter.MarkLabel(generator, label);
        foreach (var block in endBlocks)
            Emitter.MarkBlockAfter(generator, block);
        Emitter.Emit(generator, OpCodes.Ret);
    }

    /// <summary>
    /// Transpiler for the copied static constructor: rewrites every string operand "Foo" to "Baz".
    /// </summary>
    /// <remarks>
    /// The operand is changed in place so labels and exception blocks attached to the instruction survive;
    /// constructing a fresh CodeInstruction would drop them.
    /// </remarks>
    static IEnumerable<CodeInstruction> StaticConstructorTranspiler(IEnumerable<CodeInstruction> instructions)
    {
        foreach (var instruction in instructions)
        {
            if (instruction.operand is "Foo")
                instruction.operand = "Baz";
            yield return instruction;
        }
    }

    static readonly MethodInfo runClassConstructorMethod = typeof(RuntimeHelpers).GetMethod(nameof(RuntimeHelpers.RunClassConstructor));
    static readonly MethodInfo runClassConstructorUnlessExcludedMethod = typeof(TestMod).GetMethod(nameof(RunClassConstructorUnlessExcluded), AccessTools.all);

    /// <summary>
    /// Transpiler for <see cref="StaticConstructorOnStartupUtility.CallAll"/>: redirects each
    /// <c>RuntimeHelpers.RunClassConstructor</c> call to <see cref="RunClassConstructorUnlessExcluded"/>.
    /// </summary>
    /// <remarks>
    /// Swapping the call target for a static method with the identical signature (RuntimeTypeHandle -> void)
    /// leaves the evaluation stack unchanged on every path, so no branches or labels need to be injected.
    /// Logs a warning if no call was found, leaving CallAll unchanged.
    /// </remarks>
    static IEnumerable<CodeInstruction> StaticConstructorOnStartupUtilityTranspiler(IEnumerable<CodeInstruction> instructions)
    {
        var replaced = false;
        foreach (var instruction in instructions)
        {
            if (object.Equals(instruction.operand, runClassConstructorMethod))
            {
                instruction.operand = runClassConstructorUnlessExcludedMethod;
                replaced = true;
            }
            yield return instruction;
        }
        if (!replaced)
            Log.Warning("TestMod: RuntimeHelpers.RunClassConstructor call not found in StaticConstructorOnStartupUtility.CallAll; no static constructors will be excluded");
    }

    /// <summary>
    /// Stand-in for <c>RuntimeHelpers.RunClassConstructor</c> inside the patched CallAll: runs the type's static
    /// constructor unless <see cref="IsExcludedStaticConstructorOnStartupType"/> excludes it.
    /// </summary>
    static void RunClassConstructorUnlessExcluded(RuntimeTypeHandle typeHandle)
    {
        if (!IsExcludedStaticConstructorOnStartupType(Type.GetTypeFromHandle(typeHandle)))
            RuntimeHelpers.RunClassConstructor(typeHandle);
    }

    /// <summary>Returns true for types whose original static constructor must not be run by CallAll.</summary>
    static bool IsExcludedStaticConstructorOnStartupType(Type type)
    {
        return type == typeof(TestStaticClass1);
    }
}
/// <summary>
/// Target of the proof of concept. Its own static constructor logs "Foo"; TestMod excludes this type from
/// StaticConstructorOnStartupUtility.CallAll and registers a transpiled copy that logs "Baz" instead.
/// </summary>
[StaticConstructorOnStartup]
static class TestStaticClass1
{
    static TestStaticClass1() => Log.Message("Foo");
}
/// <summary>
/// Control case: a [StaticConstructorOnStartup] type TestMod leaves alone; its static constructor logs "Bar".
/// </summary>
[StaticConstructorOnStartup]
static class TestStaticClass2
{
    static TestStaticClass2() => Log.Message("Bar");
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment