Skip to content

Instantly share code, notes, and snippets.

@mttchpmn
Created March 8, 2023 02:13
Show Gist options
  • Save mttchpmn/6a5d923afa1cdb2d403fc9655398d48a to your computer and use it in GitHub Desktop.
Save mttchpmn/6a5d923afa1cdb2d403fc9655398d48a to your computer and use it in GitHub Desktop.
Unit Test Src Generator
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using GenerationAssembly;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
namespace CDD.Unit.Tests.SourceGenerators
{
[Generator]
public class PartialTestClassGenerator : ISourceGenerator
{
    /// <summary>
    /// Reported on a test-subject field when a partial test class cannot be generated for it.
    /// An exception escaping <see cref="Execute"/> would abort the whole generator run (Roslyn
    /// reports CS8785 and discards its output), so one unsupported field would otherwise break
    /// generation for every test class.
    /// </summary>
    private static readonly DiagnosticDescriptor GenerationFailed = new DiagnosticDescriptor(
        id: "CDDGEN001",
        title: "Unable to generate partial test class",
        messageFormat: "Unable to generate partial test class: {0}",
        category: "CDD.Unit.Tests.SourceGenerators",
        defaultSeverity: DiagnosticSeverity.Error,
        isEnabledByDefault: true);

    /// <summary>
    /// Registers the syntax receiver whose collected field declarations drive <see cref="Execute"/>.
    /// </summary>
    public void Initialize(GeneratorInitializationContext context)
    {
        context.RegisterForSyntaxNotifications(() => new TestSubjectSyntaxReceiver());
    }

    /// <summary>
    /// Generates one partial test class per field collected by <see cref="TestSubjectSyntaxReceiver"/>.
    /// Failures for an individual field are reported as diagnostics on that field instead of
    /// aborting generation for the remaining fields.
    /// </summary>
    public void Execute(GeneratorExecutionContext context)
    {
        if (!(context.SyntaxReceiver is TestSubjectSyntaxReceiver receiver))
            return;

        var fields = receiver.FieldDeclarations;
        if (fields is null)
            return;

        foreach (var field in fields)
        {
            try
            {
                GeneratePartialTestClass(context, field);
            }
            catch (InvalidOperationException ex)
            {
                context.ReportDiagnostic(Diagnostic.Create(GenerationFailed, field.GetLocation(), ex.Message));
            }
        }
    }

    /// <summary>
    /// Generates a partial test class, instantiating the test subject,
    /// and generating Mocks for the required constructor parameters.
    /// <br /><br />
    /// Each Mock that is instantiated will also have `SetupXXX()` and
    /// `VerifyXXX()` methods generated for every method defined on the interface.
    /// <br /><br />
    /// For any test related setup, please define a method called 'Initialize' with the following signature:
    /// <code>partial void Initialize()</code>
    /// The generated test class will create a parameterless constructor which
    /// then calls this Initialize method. Note - creating an Initialize method is optional.
    /// </summary>
    /// <exception cref="InvalidOperationException">The field or its type cannot be processed.</exception>
    private void GeneratePartialTestClass(GeneratorExecutionContext context, FieldDeclarationSyntax field)
    {
        var semanticModel = GetSemanticModel(context, field);
        var testSubjectTypeSymbol = GetTestSubjectTypeSymbol(semanticModel, field);
        var testSubjectVariableDeclaration = GetTestSubjectVariableDeclaration(semanticModel, field);
        var testSubjectConstructorParameters = GetTestSubjectConstructorParameters(testSubjectTypeSymbol).ToList();

        // Generate text components required for partial class
        var namespaceName = GetNamespaceName(testSubjectVariableDeclaration);
        var usingStatements = GetUsingStatements(testSubjectTypeSymbol, testSubjectConstructorParameters);
        var className = GetClassName(testSubjectVariableDeclaration);
        var fieldDeclarations = GetFieldDeclarations(testSubjectConstructorParameters);
        var constructorInstantiation = GetConstructorInstantiation(testSubjectVariableDeclaration, testSubjectTypeSymbol, testSubjectConstructorParameters);
        var helperMethods = GetHelperMethods(testSubjectConstructorParameters);

        // Generate partial class
        var sourceText = GenerateSourceText(
            usingStatements,
            namespaceName,
            className,
            fieldDeclarations,
            constructorInstantiation,
            helperMethods);

        // AddSource rejects duplicate hint names, so qualify with the namespace to allow
        // identically named test classes in different namespaces.
        var hintName = string.IsNullOrEmpty(namespaceName)
            ? $"{className}.generated.cs"
            : $"{namespaceName}.{className}.generated.cs";
        context.AddSource(hintName, sourceText);
    }

    /// <summary>
    /// Returns the parameters of the test subject's single instance constructor.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The type has no instance constructor (e.g. it is an interface) or has more than one.
    /// </exception>
    private IEnumerable<IParameterSymbol> GetTestSubjectConstructorParameters(INamedTypeSymbol testSubjectTypeSymbol)
    {
        // InstanceConstructors (unlike Constructors) excludes a static constructor, which would
        // otherwise be miscounted as a second constructor.
        var constructors = testSubjectTypeSymbol.InstanceConstructors;
        if (constructors.Length == 0)
            throw new InvalidOperationException("Unable to obtain constructor for test subject. Ensure you are using a concrete type and not an interface");
        if (constructors.Length > 1)
            throw new InvalidOperationException("Encountered more than one constructor for test subject");
        return constructors[0].Parameters.ToList();
    }

    /// <summary>Gets the semantic model for the syntax tree containing the field's type.</summary>
    private SemanticModel GetSemanticModel(GeneratorExecutionContext context, FieldDeclarationSyntax field)
        => context.Compilation.GetSemanticModel(field.Declaration.Type.SyntaxTree);

    /// <summary>Resolves the declared type of the test-subject field.</summary>
    /// <exception cref="InvalidOperationException">The type is not a named type.</exception>
    private INamedTypeSymbol GetTestSubjectTypeSymbol(SemanticModel semanticModel, FieldDeclarationSyntax field)
        => semanticModel.GetTypeInfo(field.Declaration.Type).Type as INamedTypeSymbol ?? throw new InvalidOperationException("Unable to obtain type symbol for test subject");

    /// <summary>Resolves the symbol of the single variable declared by the test-subject field.</summary>
    /// <exception cref="InvalidOperationException">
    /// The field declares more than one variable, or its symbol cannot be resolved.
    /// </exception>
    private ISymbol GetTestSubjectVariableDeclaration(SemanticModel semanticModel, FieldDeclarationSyntax field)
    {
        if (field.Declaration.Variables.Count > 1)
            throw new InvalidOperationException("Encountered more than one variable for field declaration");
        var result = semanticModel.GetDeclaredSymbol(field.Declaration.Variables.First());
        if (result is null)
            throw new InvalidOperationException("Unable to obtain test subject variable declaration");
        return result;
    }

    /// <summary>
    /// Builds distinct <c>using</c> directives for the namespaces of the constructor parameter
    /// types and of the test subject type (in that order). The global namespace, and types with
    /// no containing namespace such as arrays, are skipped because they have no valid directive.
    /// </summary>
    private string GetUsingStatements(ISymbol testSubjectSymbol, List<IParameterSymbol> testSubjectConstructorParameters)
    {
        var namespaces = testSubjectConstructorParameters
            .Select(p => p.Type.ContainingNamespace)
            .Concat(new[] { testSubjectSymbol.ContainingNamespace })
            .Where(ns => ns != null && !ns.IsGlobalNamespace)
            .Select(ns => $"using {ns.ToDisplayString()};")
            .Distinct();
        return string.Join("\n", namespaces);
    }

    /// <summary>
    /// Returns the namespace enclosing the test-subject field, or an empty string when it is
    /// declared in the global namespace.
    /// </summary>
    private string GetNamespaceName(ISymbol declaration)
    {
        var containingNamespace = declaration.ContainingNamespace;
        return containingNamespace is null || containingNamespace.IsGlobalNamespace
            ? string.Empty
            : containingNamespace.ToDisplayString();
    }

    /// <summary>Returns the simple name of the test class that declares the test-subject field.</summary>
    private string GetClassName(ISymbol testSubjectVariableDeclaration)
        => testSubjectVariableDeclaration.ContainingType.Name;

    /// <summary>Emits one <c>Mock&lt;T&gt;</c> field per constructor parameter, separated by "\n\t".</summary>
    private string GetFieldDeclarations(IEnumerable<IParameterSymbol> testSubjectConstructorParameters)
    {
        var fields = testSubjectConstructorParameters.Select(
            x => $"private Mock<{GetMockedTypeText(x.Type)}> {GetFieldName(x.Type.Name)} = new();");
        return string.Join("\n\t", fields);
    }

    /// <summary>
    /// Builds the <c>Mock&lt;T&gt;</c> type argument text: the simple type name followed by every
    /// type argument, so multi-parameter generics such as <c>IFoo&lt;A, B&gt;</c> are reproduced intact.
    /// </summary>
    private static string GetMockedTypeText(ITypeSymbol type)
    {
        var namedType = type as INamedTypeSymbol;
        if (namedType is null || namedType.TypeArguments.Length == 0)
            return type.Name;
        var typeArguments = namedType.TypeArguments.Select(t => t.ToDisplayString());
        return $"{type.Name}<{string.Join(", ", typeArguments)}>";
    }

    /// <summary>
    /// Derives a mock field name from a dependency type name. A conventional interface prefix
    /// is dropped ("IFooService" becomes "_fooService"), while other names are only camel-cased
    /// ("FooService" becomes "_fooService").
    /// </summary>
    private string GetFieldName(string parameterName)
    {
        var hasInterfacePrefix = parameterName.Length > 1
            && parameterName[0] == 'I'
            && char.IsUpper(parameterName[1]);
        var baseName = hasInterfacePrefix ? parameterName.Substring(1) : parameterName;
        if (baseName.Length == 0)
            return "_";
        // Invariant casing keeps generated identifiers stable regardless of the build machine's culture.
        return "_" + char.ToLowerInvariant(baseName[0]) + baseName.Substring(1);
    }

    /// <summary>Emits the assignment constructing the test subject from the mocks' <c>.Object</c> instances.</summary>
    private string GetConstructorInstantiation(ISymbol testSubjectVariableDeclaration, INamedTypeSymbol testSubjectTypeSymbol, IEnumerable<IParameterSymbol> testSubjectConstructorParameters)
    {
        var parameters = testSubjectConstructorParameters.Select(x => $"{GetFieldName(x.Type.Name)}.Object");
        var parameterList = string.Join(", ", parameters);
        return $"{testSubjectVariableDeclaration.Name} = new {testSubjectTypeSymbol.Name}({parameterList});";
    }

    /// <summary>Emits the Setup/Verify helper regions for every constructor parameter.</summary>
    private string GetHelperMethods(IEnumerable<IParameterSymbol> testSubjectConstructorParameters)
    {
        var setupMethods = testSubjectConstructorParameters.Select(GenerateHelperMethodsForParameter);
        return string.Join("\n\n\t", setupMethods);
    }

    /// <summary>
    /// Emits a <c>#region</c> with a <c>SetupXXX</c> helper for every non-void method and a
    /// <c>VerifyXXX</c> helper for every method of the parameter's type. Returns an empty string
    /// when the type has no eligible methods.
    /// </summary>
    /// <exception cref="InvalidOperationException">The parameter type is not a named type.</exception>
    private string GenerateHelperMethodsForParameter(IParameterSymbol parameter)
    {
        if (!(parameter.Type is INamedTypeSymbol parameterType))
            throw new InvalidOperationException($"Unable to obtain methods for parameter: {parameter}");

        // Only ordinary instance methods can be written as x => x.Method(...). Property/event
        // accessors (get_X, add_X, ...) cannot be invoked by name in C#, constructors are not
        // callable on a mock, and static members cannot be set up through a Moq instance.
        var availableMethods = parameterType.GetMembers()
            .OfType<IMethodSymbol>()
            .Where(m => m.MethodKind == MethodKind.Ordinary && !m.IsStatic)
            .ToList();

        var setupMethods = availableMethods.Where(m => !m.ReturnsVoid).Select(m => GenerateSetupMethod(parameter, m));
        var verifyMethods = availableMethods.Select(m => GenerateVerifyMethod(parameter, m));
        var setupText = string.Join("\n\n\t", setupMethods);
        var verifyText = string.Join("\n\n\t", verifyMethods);
        if (string.IsNullOrWhiteSpace(setupText) && string.IsNullOrWhiteSpace(verifyText))
            return "";
        return $"#region {parameter.Type.Name} helper methods:\n\t" + setupText + "\n\n\t" + verifyText + "\n\t#endregion";
    }

    /// <summary>
    /// Emits <c>SetupXXX(returnValue, param1 = null, ...)</c>. Each argument matcher accepts any
    /// value when the corresponding helper argument is null, otherwise it requires equality.
    /// </summary>
    private string GenerateSetupMethod(IParameterSymbol parameter, IMethodSymbol method)
    {
        if (method is null)
            return "";
        // Use the parameter's type (not the parameter symbol's display) so modifiers like
        // 'params' never end up inside a type-argument position.
        var methodParameters = method.Parameters.Select((x, i) => $"It.Is<{x.Type.ToDisplayString()}>(y => param{i + 1} == null || y == param{i + 1})");
        var parametersText = string.Join(", ", methodParameters);
        var nullableParametersText = GetNullableParametersText(method);
        var fieldName = GetFieldName(parameter.Type.Name);
        var returnType = method.ReturnType.ToDisplayString();
        return $@"private void Setup{method.Name}({returnType} returnValue{nullableParametersText})
{{
{fieldName}.Setup(x => x.{method.Name}({parametersText})).Returns(returnValue);
}}";
    }

    /// <summary>
    /// Emits <c>VerifyXXX(timesCalled = null, param1 = null, ...)</c>. Null arguments fall back to
    /// <c>It.IsAny&lt;T&gt;()</c>, and a null <c>timesCalled</c> defaults to <c>Times.AtLeastOnce()</c>.
    /// </summary>
    private string GenerateVerifyMethod(IParameterSymbol parameter, IMethodSymbol method)
    {
        if (method is null)
            return "";
        var methodParameters = method.Parameters.Select((x, i) => $"param{i + 1} ?? It.IsAny<{x.Type.ToDisplayString()}>()");
        var parametersText = string.Join(", ", methodParameters);
        var nullableParametersText = GetNullableParametersText(method);
        var fieldName = GetFieldName(parameter.Type.Name);
        return $@"private void Verify{method.Name}(Times? timesCalled = null{nullableParametersText})
{{
{fieldName}.Verify(x => x.{method.Name}({parametersText}), timesCalled ?? Times.AtLeastOnce());
}}";
    }

    /// <summary>
    /// Builds the trailing <c>, T? param1 = null, ...</c> helper parameter list for a method, or
    /// an empty string when the method has no parameters.
    /// </summary>
    private static string GetNullableParametersText(IMethodSymbol method)
    {
        if (method.Parameters.Length == 0)
            return string.Empty;

        // Every helper argument defaults to null, so its type must be nullable. A '?' is
        // appended unless the type is already annotated. This includes optional value-type
        // parameters (e.g. 'CancellationToken ct = default'); without the '?' they would get an
        // invalid '= null' default.
        var nullableParameters = method.Parameters.Select((x, i) =>
        {
            var suffix = x.NullableAnnotation == NullableAnnotation.Annotated ? "" : "?";
            return $"{x.Type.ToDisplayString()}{suffix} param{i + 1} = null";
        });
        return ", " + string.Join(", ", nullableParameters);
    }

    /// <summary>
    /// Assembles the generated partial test class. The namespace declaration is omitted when
    /// <paramref name="namespaceName"/> is empty (test class in the global namespace).
    /// </summary>
    private SourceText GenerateSourceText(
        string usingStatements,
        string namespaceName,
        string className,
        string fieldDeclarations,
        string constructorInstantiation,
        string helperMethods)
    {
        var namespaceDeclaration = string.IsNullOrEmpty(namespaceName)
            ? string.Empty
            : $"namespace {namespaceName};";

        return SourceText.From(
$@"// <auto-generated>
#pragma warning disable CS8073
#nullable enable
using System;
using Moq;
{usingStatements}
{namespaceDeclaration}
public partial class {className}
{{
{fieldDeclarations}
public {className}()
{{
{constructorInstantiation}
Initialize();
}}
partial void Initialize();
{helperMethods}
}}
#pragma warning restore CS8073
",
            Encoding.UTF8);
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment