Skip to content

Instantly share code, notes, and snippets.

@kellypleahy
Created May 24, 2011 05:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kellypleahy/988177 to your computer and use it in GitHub Desktop.
Save kellypleahy/988177 to your computer and use it in GitHub Desktop.
Cecil-based unit tests.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Text;
using Milliman.MGAlfa.AplInterface;
using Mono.Cecil;
using Mono.Cecil.Cil;
using NUnit.Framework;
using MethodBody = Mono.Cecil.Cil.MethodBody;
namespace Milliman.MGAlfa.Utility.Tests
{
[TestFixture]
public class AssemblyReferenceChecks
{
/// <summary>
/// Collects the assembly that declares <c>TxnEditAin2Category</c> plus the assemblies
/// gathered by <c>_GetReferencedAssemblies</c>, and fails if any of them references a
/// test-only dependency: an assembly whose name contains ".tests", "nunit" or
/// "rhino.mocks" (case-insensitive).
/// </summary>
[Test]
public void Verify_That_Production_Assemblies_Only_Reference_Production_Dependencies()
{
    // Key: substring searched for in the referenced assembly's simple name.
    // Value: label inserted into the failure message.
    // The array order is the order in which the checks are applied; the first hit fails the test.
    var forbiddenReferences = new[]
    {
        new KeyValuePair<string, string>(".tests", "Tests"),
        new KeyValuePair<string, string>("nunit", "NUnit"),
        new KeyValuePair<string, string>("rhino.mocks", "Rhino.Mocks"),
    };

    var assemblies = new Dictionary<string, Assembly>();
    _GetReferencedAssemblies(assemblies, typeof(TxnEditAin2Category).Assembly);

    foreach (var assembly in assemblies.Values)
    {
        foreach (var referencedAssembly in assembly.GetReferencedAssemblies())
        {
            var referencedAssemblyName = referencedAssembly.Name;
            foreach (var forbidden in forbiddenReferences)
            {
                // Assembly names are identifiers, not linguistic text: compare ordinally and
                // case-insensitively so the result does not depend on the current culture
                // (ToLower() under e.g. tr-TR maps 'I' to a dotless 'ı' and misses the match).
                if (referencedAssemblyName.IndexOf(forbidden.Key, StringComparison.OrdinalIgnoreCase) < 0)
                    continue;

                Assert.Fail(string.Format("Assembly: {0} references a {1} assembly - {2}",
                    assembly.GetName().Name,
                    forbidden.Value,
                    referencedAssemblyName));
            }
        }
    }
}
/// <summary>
/// Recursively adds <paramref name="rootAssembly"/> and every assembly reachable from it
/// through references whose simple name starts with "Milliman.MGAlfa." to
/// <paramref name="assemblies"/>, keyed by the loaded assembly's full name.
/// </summary>
/// <param name="assemblies">Accumulator for the visited assemblies; entries already present are left untouched.</param>
/// <param name="rootAssembly">Assembly to record and whose references are followed.</param>
private static void _GetReferencedAssemblies(Dictionary<string, Assembly> assemblies, Assembly rootAssembly)
{
    // Guard on the identity of the assembly that was actually loaded. The name stored in a
    // reference can differ from the name of the assembly the loader returns for it (for example
    // when a different version is bound), so the ContainsKey check on the reference name below
    // is not sufficient on its own and an unguarded Add would throw on the duplicate key.
    if (assemblies.ContainsKey(rootAssembly.FullName))
        return;
    assemblies.Add(rootAssembly.FullName, rootAssembly);

    foreach (var assemblyName in rootAssembly.GetReferencedAssemblies())
    {
        Debug.WriteLine(assemblyName.Name);

        // Only follow references into our own assemblies. Ordinal comparison: this is an
        // identifier prefix, so the match must not vary with the current culture.
        if (!assemblyName.Name.StartsWith("Milliman.MGAlfa.", StringComparison.Ordinal))
            continue;

        // Cheap pre-check that avoids loading an assembly we have already recorded.
        if (!assemblies.ContainsKey(assemblyName.FullName))
            _GetReferencedAssemblies(assemblies, Assembly.Load(assemblyName));
    }
}
/// <summary>
/// Opens the assembly that declares <c>TxnEditAin2Category</c> with Cecil and hands it to
/// <c>_ProcessAssembly</c> together with a fresh <c>RefVisitor</c> and an empty set of
/// already-processed assembly names.
/// </summary>
[Test]
public void VerifyThatNoNotNullArgumentConstraintsAreViolated()
{
    var rootAssemblyPath = typeof(TxnEditAin2Category).Assembly.Location;
    var processedAssemblyNames = new HashSet<string>();

    _ProcessAssembly(
        AssemblyFactory.GetAssembly(rootAssemblyPath),
        new RefVisitor(),
        processedAssemblyNames);
}
#region Supporting reflection (Cecil) goop for VerifyThatNoNotNullArgumentConstraintsAreViolated
/// <summary>
/// Reflection visitor that descends module -> types -> methods and runs a
/// <c>CodeVisitor</c> over the body of every method that has one.
/// </summary>
private class RefVisitor : BaseReflectionVisitor
{
    /// <summary>Forwards the module's type collection to this visitor.</summary>
    public override void VisitModuleDefinition(ModuleDefinition module)
    {
        var types = module.Types;
        types.Accept(this);
        // VisitCollection is inherited from BaseReflectionVisitor; presumably it dispatches
        // each element back to this visitor - confirm against the Cecil version in use.
        VisitCollection(types);
    }

    /// <summary>Forwards the type's method collection to this visitor.</summary>
    public override void VisitTypeDefinition(TypeDefinition type)
    {
        var methods = type.Methods;
        methods.Accept(this);
        VisitCollection(methods);
    }

    /// <summary>Scans the method's body, if it has one, with a CodeVisitor bound to that method.</summary>
    public override void VisitMethodDefinition(MethodDefinition method)
    {
        // Tracing aid, left disabled:
        //Debug.WriteLine(string.Format("Visited method: {0}::{1}", method.DeclaringType.FullName, method.Name));
        if (!method.HasBody)
            return;

        method.Body.Accept(new CodeVisitor { Method = method });
    }
}
/// <summary>
/// Code visitor that looks at the call sites in one method body and checks whether a
/// literal null (<c>ldnull</c>) appears to be passed for a parameter that carries custom
/// attributes; such parameters are handed to a <c>ReflectionVisitor</c> for inspection.
/// Only calls to members of types whose full name starts with "Milliman.MGAlfa." are examined.
/// </summary>
private class CodeVisitor : BaseCodeVisitor
{
    /// <summary>The method whose body is being scanned; passed on as the calling method.</summary>
    public MethodDefinition Method;

    /// <summary>Forwards the body's instruction collection to this visitor.</summary>
    public override void VisitMethodBody(MethodBody body)
    {
        body.Instructions.Accept(this);
    }

    /// <summary>Visits only the instructions whose flow control is <c>Call</c>.</summary>
    public override void VisitInstructionCollection(InstructionCollection instructions)
    {
        // optimization to get only instructions that are calls
        foreach (Instruction candidate in instructions)
        {
            if (candidate.OpCode.FlowControl == FlowControl.Call)
                candidate.Accept(this);
        }
    }

    /// <summary>
    /// For <c>call</c>, <c>callvirt</c> and <c>newobj</c> instructions, pairs the called
    /// method's parameters (last to first) with the instructions immediately preceding the
    /// call and inspects every parameter that lines up with an <c>ldnull</c>.
    /// </summary>
    public override void VisitInstruction(Instruction instr)
    {
        switch (instr.OpCode.Code)
        {
            case Code.Call:
            case Code.Callvirt:
            case Code.Newobj:
                var method = (MethodReference)instr.Operand;

                // Type names are identifiers: compare ordinally so the filter is
                // culture-independent (this runs once per call site).
                if (!method.DeclaringType.FullName.StartsWith("Milliman.MGAlfa.", StringComparison.Ordinal))
                    return;

                // try to look backwards in the instructions to see if we can find pushes that look like loading of nulls.
                // This is a positional heuristic: parameter i is matched with the instruction
                // (Count - i) steps before the call, which only lines up when every later
                // argument is produced by exactly one instruction.
                // Stop when we run out of preceding instructions instead of dereferencing null.
                var pushCandidate = instr.Previous;
                for (var i = method.Parameters.Count - 1;
                     i >= 0 && pushCandidate != null;
                     i--, pushCandidate = pushCandidate.Previous)
                {
                    if (pushCandidate.OpCode.Code != Code.Ldnull)
                        continue;
                    if (!method.Parameters[i].HasCustomAttributes)
                        continue;

                    var reflVisitor = new ReflectionVisitor
                                          {
                                              CallingMethod = Method,
                                              CalledMethod = method,
                                              Parameter = method.Parameters[i]
                                          };
                    method.Parameters[i].Accept(reflVisitor);
                }
                break;
        }
    }
}
/// <summary>
/// Reflection visitor applied to a single parameter of a called method. Fails the test
/// when one of the visited custom attributes is constructed by a type named
/// "NotNullAttribute", reporting the calling method, the called method and the parameter.
/// </summary>
private class ReflectionVisitor : BaseReflectionVisitor
{
    /// <summary>Method containing the call site; used in the failure message.</summary>
    public MethodDefinition CallingMethod;

    /// <summary>Method being called; used in the failure message.</summary>
    public MethodReference CalledMethod;

    /// <summary>Parameter under inspection; its name is used in the failure message.</summary>
    public ParameterDefinition Parameter;

    /// <summary>Forwards the attribute collection to the inherited VisitCollection.</summary>
    public override void VisitCustomAttributeCollection(CustomAttributeCollection customAttrs)
    {
        VisitCollection(customAttrs);
    }

    /// <summary>Returns the type's full name with every "Milliman.MGAlfa." occurrence removed, for shorter messages.</summary>
    private static string _FormatTypeName(TypeReference typeRef)
    {
        return typeRef.FullName.Replace("Milliman.MGAlfa.", "");
    }

    /// <summary>
    /// Renders a method as <c>Type::Name(</c> followed by one parameter type per line and a
    /// closing parenthesis. A parameterless method renders as <c>Type::Name()</c>.
    /// </summary>
    private static string _FormatMethod(MethodReference method)
    {
        var sb = new StringBuilder();
        sb.AppendFormat("{0}::{1}(",
                        _FormatTypeName(method.DeclaringType),
                        method.Name);

        var paramCount = method.Parameters.Count;
        if (paramCount > 0)
        {
            sb.AppendLine();
            for (var i = 0; i < paramCount - 1; i++)
                sb.AppendFormat(" {0},\n", _FormatTypeName(method.Parameters[i].ParameterType));
            sb.AppendFormat(" {0}", _FormatTypeName(method.Parameters[paramCount - 1].ParameterType));
        }

        // Close the parameter list unconditionally; previously the ")" was only written
        // when there was at least one parameter, leaving "Type::Name(" for parameterless methods.
        sb.Append(")");
        return sb.ToString();
    }

    /// <summary>Fails the test if the attribute's constructor is declared on a type named "NotNullAttribute".</summary>
    public override void VisitCustomAttribute(CustomAttribute customAttr)
    {
        // Matched by simple type name only, so the namespace of the attribute is not considered.
        if (customAttr.Constructor.DeclaringType.Name == "NotNullAttribute")
        {
            Assert.Fail(
                "\nMethod {0}\n calls {1}\n with a null value\n for NotNull parameter {2}",
                _FormatMethod(CallingMethod),
                _FormatMethod(CalledMethod),
                Parameter.Name);
        }
    }
}
/// <summary>
/// Runs <paramref name="visitor"/> over the main module of <paramref name="assyDef"/> and then,
/// recursively, over every referenced assembly whose name starts with "Milliman.MGAlfa."
/// (case-insensitive). Each assembly is processed at most once per
/// <paramref name="alreadyProcessedDictionary"/>.
/// </summary>
/// <param name="assyDef">Cecil definition of the assembly to process.</param>
/// <param name="visitor">Visitor applied to each assembly's main module.</param>
/// <param name="alreadyProcessedDictionary">Full names of the assemblies processed so far; updated by this call.</param>
private static void _ProcessAssembly(AssemblyDefinition assyDef, IReflectionVisitor visitor, HashSet<string> alreadyProcessedDictionary)
{
    // HashSet.Add returns false when the name is already present, so a single call both
    // tests for and records the visit.
    if (!alreadyProcessedDictionary.Add(assyDef.Name.FullName))
        return;

    assyDef.MainModule.Accept(visitor);

    foreach (var assyNameRef in assyDef.MainModule.AssemblyReferences.Cast<AssemblyNameReference>())
    {
        // Assembly names are identifiers: use an ordinal, case-insensitive comparison so the
        // filter does not change with the machine's current culture.
        if (!assyNameRef.Name.StartsWith("Milliman.MGAlfa.", StringComparison.OrdinalIgnoreCase))
            continue;

        try
        {
            var assy = Assembly.Load(new AssemblyName(assyNameRef.FullName));
            var location = assy.Location;
            var assyDef2 = AssemblyFactory.GetAssembly(location);
            // Share the resolver of the referencing assembly with the newly opened definition.
            assyDef2.Resolver = assyDef.Resolver;
            _ProcessAssembly(assyDef2, visitor, alreadyProcessedDictionary);
        }
        catch (AssertionException)
        {
            // A failed assertion raised by the visitor is the test result: let it through.
            throw;
        }
        catch (Exception e)
        {
            // Best effort: a referenced assembly that cannot be processed is reported and skipped.
            // NOTE(review): the recursive call sits inside this try block, so any non-assertion
            // exception thrown while analysing the referenced assembly is also reported here as a
            // load failure - confirm that this is intended.
            Debug.WriteLine(string.Format("Couldn't load assembly {0} with error {1}", assyNameRef.FullName, e));
        }
    }
}
#endregion
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment