Skip to content

Instantly share code, notes, and snippets.

@leonardosnt
Created May 30, 2017 13:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save leonardosnt/86e8f1209613992cb3e0bcff3b71f8f5 to your computer and use it in GitHub Desktop.
Save leonardosnt/86e8f1209613992cb3e0bcff3b71f8f5 to your computer and use it in GitHub Desktop.
using System;
using System.Linq;
using System.Text;
using Mono.Cecil;
using Mono.Cecil.Cil;
using Mono.Cecil.Rocks;
using Mono.Collections.Generic;
// 23/07/2016 leonardosnt
namespace ConsoleApplication2 {
/// <summary>
/// Command-line tool that rewrites an assembly with Mono.Cecil: for every method
/// parameter carrying the configured "not null" attribute, a guard is prepended to
/// the method body that throws <see cref="ArgumentException"/> when the argument is null.
/// </summary>
public class NotNullReplacer {
    /// <summary>
    /// Entry point. Expects three arguments: the input assembly path, the output
    /// assembly path and the full type name of the attribute that marks parameters
    /// as not-null.
    /// </summary>
    /// <param name="args">[assembly_file] [out_file] [notnull_type]</param>
    public static void Main(string[] args) {
        if (args.Length < 3) {
            Console.WriteLine("Missing args: [assembly_file] [out_file] [notnull_type]");
            // A usage error must not look like success to the calling shell/build script.
            Environment.Exit(1);
        }

        var asmFile = args[0];
        var outFile = args[1];
        var notNullType = args[2];
        Console.WriteLine($"asmFile = '{asmFile}'" );
        Console.WriteLine($"outFile = '{outFile}'" );
        Console.WriteLine($"notNullType = '{notNullType}'" );

        var asmDef = AssemblyDefinition.ReadAssembly(asmFile);
        var module = asmDef.MainModule;

        // The constructor reference is identical for every injected check, so import it once.
        var exceptionCtor = module.Import(typeof(ArgumentException).GetConstructor(new[] { typeof(string) }));

        // GetTypes() also walks nested types; Types would only list top-level ones.
        var methods = module.GetTypes()
            .Where(t => t.IsClass)
            .SelectMany(t => t.Methods)
            .Where(m => !m.IsCompilerControlled && !m.IsAbstract && m.HasBody)
            .Where(m => m.Parameters.Count > 0);

        foreach (var md in methods) {
            InjectNullChecks(md, notNullType, exceptionCtor);
        }

        asmDef.Write(outFile);

        // ReadKey throws when stdin is redirected (pipes, CI); only pause for interactive runs.
        if (!Console.IsInputRedirected) {
            Console.ReadKey();
        }
    }

    /// <summary>
    /// Prepends one null guard per annotated reference-type parameter of
    /// <paramref name="md"/>. Each generated instruction is echoed to the console.
    /// </summary>
    /// <param name="md">Method whose body is rewritten in place.</param>
    /// <param name="notNullType">Full name of the marker attribute type.</param>
    /// <param name="exceptionCtor">Imported <c>ArgumentException(string)</c> constructor.</param>
    private static void InjectNullChecks(MethodDefinition md, string notNullType, MethodReference exceptionCtor) {
        // Reversed because every guard is inserted at the very start of the body:
        // handling the last parameter first leaves the guards in declaration order.
        var annotated = md.Parameters
            .Where(p => p.CustomAttributes.Any(a => string.Equals(a.AttributeType.FullName, notNullType, StringComparison.Ordinal)))
            // Value types can never be null, and comparing them with ldnull is invalid IL.
            .Where(p => !p.ParameterType.IsPrimitive && !p.ParameterType.IsValueType)
            .Reverse()
            .ToArray();

        var instrs = md.Body.Instructions;
        if (annotated.Length == 0 || instrs.Count == 0) {
            return;
        }

        // Work on long-form opcodes and let OptimizeMacros choose the short encodings
        // afterwards, so branch distances and argument indices never have to be checked by hand.
        md.Body.SimplifyMacros();

        var parameterList = DescribeParameters(md);
        foreach (var p in annotated) {
            var check = new[] {
                Instruction.Create(OpCodes.Ldarg, p),
                // Non-null argument: continue with what is currently the first instruction
                // (the original body, or the guard that was inserted before this one).
                Instruction.Create(OpCodes.Brtrue, instrs[0]),
                Instruction.Create(OpCodes.Ldstr, $"Argument for [NotNull] parameter '{p.Name}' " +
                    $"of '{md.DeclaringType}::{md.Name}({parameterList})' must not be null."),
                Instruction.Create(OpCodes.Newobj, exceptionCtor),
                Instruction.Create(OpCodes.Throw),
            };

            for (var i = 0; i < check.Length; i++) {
                Console.WriteLine(check[i]);
                instrs.Insert(i, check[i]);
            }
        }

        md.Body.OptimizeMacros();
    }

    /// <summary>
    /// Formats the parameter list of <paramref name="md"/> as "type name, type name, ..."
    /// for use in the exception message.
    /// </summary>
    private static string DescribeParameters(MethodDefinition md) {
        return string.Join(", ", md.Parameters.Select(p => FriendlyTypeName(p.ParameterType.FullName) + " " + p.Name));
    }

    /// <summary>
    /// Maps a handful of framework type names to their C# keyword; any other name is
    /// returned unchanged.
    /// </summary>
    private static string FriendlyTypeName(string fullName) {
        switch (fullName) {
            case "System.String": return "string";
            case "System.Int32": return "int";
            case "System.Int64": return "long";
            case "System.UInt32": return "uint";
            case "System.UInt64": return "ulong";
            default: return fullName;
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment