Created
December 16, 2019 06:21
-
-
Save lbmaian/1b425c1c9d7ec8e5869162d463174742 to your computer and use it in GitHub Desktop.
Proof of concept of static constructor patching in RimWorld
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Reflection; | |
using System.Reflection.Emit; | |
using System.Runtime.CompilerServices; | |
using Harmony; | |
using Harmony.ILCopying; | |
using Verse; | |
// Outputs in order: | |
// TestMod | |
// Bar | |
// Baz | |
// Notably does NOT output: Foo | |
namespace TestMod | |
{ | |
/// <summary>
/// Proof-of-concept mod that "patches" the static constructor of a <c>[StaticConstructorOnStartup]</c> class.
/// <c>StaticConstructorOnStartupUtility.CallAll</c> is transpiled so it skips the original
/// <see cref="TestStaticClass1"/>. A transpiled copy of that class (same full name, same attribute) is emitted into a
/// dynamic assembly and appended to this mod's loaded assemblies.
/// </summary>
class TestMod : Mod
{
    // Value assigned to HarmonyInstance.DEBUG while this mod applies its patches.
    const bool DEBUG = false;

    /// <summary>
    /// Logs "TestMod", installs the <c>CallAll</c> transpiler, and registers the dynamic assembly that holds the
    /// transpiled copy of <see cref="TestStaticClass1"/>.
    /// </summary>
    /// <param name="content">This mod's content pack; the dynamic assembly is added to its loaded assemblies.</param>
    public TestMod(ModContentPack content) : base(content)
    {
        Log.Message("TestMod");
        // Restore the previous setting afterwards rather than forcing it off for every other Harmony user.
        var previousDebug = HarmonyInstance.DEBUG;
        HarmonyInstance.DEBUG = DEBUG;
        try
        {
            var harmony = HarmonyInstance.Create("TestMod");
            harmony.Patch(typeof(StaticConstructorOnStartupUtility).GetMethod(nameof(StaticConstructorOnStartupUtility.CallAll)),
                transpiler: new HarmonyMethod(typeof(TestMod), nameof(StaticConstructorOnStartupUtilityTranspiler)));
            content.assemblies.loadedAssemblies.Add(DefinePatchedStaticConstructorCopy(typeof(TestStaticClass1)));
        }
        finally
        {
            HarmonyInstance.DEBUG = previousDebug;
        }
    }

    /// <summary>
    /// Emits a run-only dynamic assembly containing a type with the same full name and type attributes as
    /// <paramref name="targetType"/>, tagged <c>[StaticConstructorOnStartup]</c>. The new type's static constructor
    /// is a copy of the original one, passed through <see cref="StaticConstructorTranspiler"/>.
    /// </summary>
    /// <param name="targetType">Type whose static constructor is copied.</param>
    /// <returns>The dynamic assembly holding the created type.</returns>
    /// <exception cref="ArgumentException"><paramref name="targetType"/> declares no static constructor.</exception>
    static AssemblyBuilder DefinePatchedStaticConstructorCopy(Type targetType)
    {
        var staticConstructor = targetType.GetConstructor(BindingFlags.Static | BindingFlags.NonPublic, null, Type.EmptyTypes, null);
        if (staticConstructor == null)
            throw new ArgumentException($"Type {targetType.FullName} has no static constructor to copy", nameof(targetType));

        var assemblyName = new AssemblyName(typeof(TestMod).Assembly.GetName().Name + "Patched");
        var assemblyBuilder = AppDomain.CurrentDomain.DefineDynamicAssembly(assemblyName, AssemblyBuilderAccess.Run);
        var moduleBuilder = assemblyBuilder.DefineDynamicModule(assemblyName.Name);
        // DefineType only takes a name and TypeAttributes, so the startup attribute has to be re-applied explicitly.
        var typeBuilder = moduleBuilder.DefineType(targetType.FullName, targetType.Attributes);
        typeBuilder.SetCustomAttribute(new CustomAttributeBuilder(typeof(StaticConstructorOnStartup).GetConstructor(Type.EmptyTypes), new object[0]));

        var constructorBuilder = typeBuilder.DefineConstructor(staticConstructor.Attributes, staticConstructor.CallingConvention, Type.EmptyTypes);
        var generator = constructorBuilder.GetILGenerator();
        var methodCopier = new MethodCopier(staticConstructor, generator);
        methodCopier.AddTranspiler(typeof(TestMod).GetMethod(nameof(StaticConstructorTranspiler), AccessTools.all));
        var endLabels = new List<Label>();
        var endBlocks = new List<ExceptionBlock>();
        methodCopier.Finalize(endLabels, endBlocks);
        // Mark the end labels / exception blocks Finalize collected, then terminate the body with ret
        // (presumably the copier redirects the original ret instructions to these end labels — confirm against Harmony).
        foreach (var label in endLabels)
            Emitter.MarkLabel(generator, label);
        foreach (var block in endBlocks)
            Emitter.MarkBlockAfter(generator, block);
        Emitter.Emit(generator, OpCodes.Ret);

        typeBuilder.CreateType();
        return assemblyBuilder;
    }

    /// <summary>
    /// Transpiler for the copied static constructor: rewrites every <c>ldstr "Foo"</c> to <c>ldstr "Baz"</c>.
    /// </summary>
    static IEnumerable<CodeInstruction> StaticConstructorTranspiler(IEnumerable<CodeInstruction> instructions)
    {
        foreach (var instruction in instructions)
        {
            // Change the operand in place so any labels/exception blocks attached to the instruction are kept.
            if (instruction.opcode == OpCodes.Ldstr && instruction.operand is "Foo")
                instruction.operand = "Baz";
            yield return instruction;
        }
    }

    static readonly MethodInfo typeHandleGetter = typeof(Type).GetProperty(nameof(Type.TypeHandle)).GetGetMethod();
    static readonly MethodInfo runClassConstructorMethod = typeof(RuntimeHelpers).GetMethod(nameof(RuntimeHelpers.RunClassConstructor));

    /// <summary>
    /// Transpiler for <c>StaticConstructorOnStartupUtility.CallAll</c>. It inserts a guard in front of the first
    /// <c>type.TypeHandle</c> read that feeds <c>RuntimeHelpers.RunClassConstructor</c>. Types accepted by
    /// <see cref="IsExcludedStaticConstructorOnStartupType"/> are popped and execution jumps past the
    /// RunClassConstructor call.
    /// </summary>
    /// <param name="instructions">IL of <c>CallAll</c>.</param>
    /// <param name="generator">Used to define the labels the inserted branches target.</param>
    /// <returns>The rewritten IL, or the IL unchanged (with an error logged) if the expected call pattern is absent.</returns>
    static IEnumerable<CodeInstruction> StaticConstructorOnStartupUtilityTranspiler(IEnumerable<CodeInstruction> instructions, ILGenerator generator)
    {
        // Copy instead of casting: the incoming sequence is not guaranteed to be a List.
        var instructionList = new List<CodeInstruction>(instructions);
        var typeHandleGetIndex = instructionList.FindIndex(instruction => object.Equals(instruction.operand, typeHandleGetter));
        var runClassConstructorCallIndex = typeHandleGetIndex < 0
            ? -1
            : instructionList.FindIndex(typeHandleGetIndex + 1, instruction => object.Equals(instruction.operand, runClassConstructorMethod));
        if (runClassConstructorCallIndex < 0 || runClassConstructorCallIndex + 1 >= instructionList.Count)
        {
            Log.Error("TestMod: RunClassConstructor(type.TypeHandle) not found in StaticConstructorOnStartupUtility.CallAll; leaving it unpatched");
            return instructionList;
        }

        // Jump target for excluded types: the instruction right after RunClassConstructor (label it if needed).
        var afterRunClassConstructor = instructionList[runClassConstructorCallIndex + 1];
        if (afterRunClassConstructor.labels.Count == 0)
            afterRunClassConstructor.labels.Add(generator.DefineLabel());
        var skipLabel = afterRunClassConstructor.labels[0];

        // The original get_TypeHandle instruction becomes the target for non-excluded types. Labels it already had
        // move onto the inserted check so existing jumps still pass through the guard.
        var typeHandleGet = instructionList[typeHandleGetIndex];
        var runLabel = generator.DefineLabel();
        var checkStart = new CodeInstruction(OpCodes.Dup); // duplicate `type` from type.TypeHandle (stack: type, type)
        checkStart.labels.AddRange(typeHandleGet.labels);
        typeHandleGet.labels.Clear();
        typeHandleGet.labels.Add(runLabel);

        instructionList.InsertRange(typeHandleGetIndex, new[]
        {
            checkStart,
            new CodeInstruction(OpCodes.Call, typeof(TestMod).GetMethod(nameof(IsExcludedStaticConstructorOnStartupType), AccessTools.all)), // stack: type, bool
            new CodeInstruction(OpCodes.Brfalse_S, runLabel), // not excluded: continue to type.TypeHandle (stack: type)
            // Excluded: discard `type` before jumping, so both paths reach skipLabel with the same stack.
            new CodeInstruction(OpCodes.Pop),
            new CodeInstruction(OpCodes.Br_S, skipLabel),
        });
        return instructionList;
    }

    /// <summary>Returns true for types whose original static constructor <c>CallAll</c> must not run.</summary>
    static bool IsExcludedStaticConstructorOnStartupType(Type type)
    {
        return type == typeof(TestStaticClass1);
    }
}
/// <summary>
/// Startup class whose original static constructor is excluded from <c>CallAll</c> by <c>TestMod</c>.
/// The transpiled copy logs "Baz" in place of "Foo".
/// </summary>
[StaticConstructorOnStartup]
internal static class TestStaticClass1
{
    static TestStaticClass1() => Log.Message("Foo");
}
/// <summary>
/// Control startup class left untouched by <c>TestMod</c>; its static constructor logs "Bar".
/// </summary>
[StaticConstructorOnStartup]
internal static class TestStaticClass2
{
    static TestStaticClass2() => Log.Message("Bar");
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment