Skip to content

Instantly share code, notes, and snippets.

@haacked
Last active May 14, 2023 18:38
Show Gist options
  • Save haacked/00de560d00692b7f4859336c747af10e to your computer and use it in GitHub Desktop.
Save haacked/00de560d00692b7f4859336c747af10e to your computer and use it in GitHub Desktop.
Roslyn Analyzer to warn about access to forbidden types
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;
// CREDIT: https://github.com/dotnet/roslyn-analyzers/blob/master/src/Microsoft.CodeAnalysis.BannedApiAnalyzers/Core/SymbolIsBannedAnalyzer.cs
/// <summary>
/// Reports a warning wherever an analyzed operation refers to a type whose fully qualified name is on a
/// deny list (by default System.Console, System.Environment, System.IntPtr and System.Type).
/// </summary>
/// <remarks>
/// A type counts as referenced when it is created (object or array creation), when one of its members is
/// invoked or referenced (method, field, property, event), when its address is taken, when it is the
/// target of a conversion, or when it declares a user-defined operator that is used. Generic type
/// arguments, array element types, pointed-at types and containing types are checked as well.
/// NOTE(review): when analyzing CSharpScript compilations, some top-level statements (for example
/// <c>Console.WriteLine("test");</c>) were observed not to produce diagnostics; the cause is unconfirmed.
/// </remarks>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class ForbiddenTypeAnalyzer : DiagnosticAnalyzer
{
    /// <summary>Diagnostic ID reported by this analyzer.</summary>
    public const string DiagnosticId = nameof(ForbiddenTypeAnalyzer);

    static readonly LocalizableString Description = "Restricts the set of types that may be used.";
    const string Title = "Forbidden Type Analyzer";
    const string MessageFormat = "Access to type {0} is forbidden.";
    const string Category = "API Usage";

    // Names are rendered fully qualified with namespaces and containing types ("System.Int32", never the
    // "int" keyword) and without nullable-reference annotations, so a deny-list entry matches however the
    // type is spelled at the use site. Type parameters are kept, so generic entries are written as
    // e.g. "System.Collections.Generic.List<T>".
    static readonly SymbolDisplayFormat TypeNameFormat = new SymbolDisplayFormat(
        typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces,
        genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters);

    // Fully qualified names (see TypeNameFormat) that trigger a diagnostic; compared ordinally.
    readonly HashSet<string> _forbiddenTypeNames;

    static readonly IEnumerable<string> DefaultTypeAccessDenyList = new[]
    {
        "System.Console",
        "System.Environment",
        "System.IntPtr",
        "System.Type"
    };

    static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(
        DiagnosticId,
        Title,
        MessageFormat,
        Category,
        DiagnosticSeverity.Warning,
        isEnabledByDefault: true,
        Description);

    /// <summary>Creates the analyzer with the default deny list.</summary>
    public ForbiddenTypeAnalyzer() : this(DefaultTypeAccessDenyList)
    {
    }

    ForbiddenTypeAnalyzer(IEnumerable<string> forbiddenTypeNames)
    {
        _forbiddenTypeNames = forbiddenTypeNames.ToHashSet(StringComparer.Ordinal);
    }

    public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(Rule);

    /// <summary>Registers one operation callback covering every operation kind that can reference a type.</summary>
    public override void Initialize(AnalysisContext compilationContext)
    {
        compilationContext.EnableConcurrentExecution();
        compilationContext.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);
        compilationContext.RegisterOperationAction(context =>
        {
            foreach (var type in GetReferencedTypes(context.Operation))
            {
                // Stop at the first forbidden type so each operation yields at most one diagnostic.
                if (!VerifyType(context.ReportDiagnostic, type, context.Operation.Syntax))
                {
                    return;
                }
            }
        },
        OperationKind.ObjectCreation,
        OperationKind.Invocation,
        OperationKind.EventReference,
        OperationKind.FieldReference,
        OperationKind.MethodReference,
        OperationKind.PropertyReference,
        OperationKind.ArrayCreation,
        OperationKind.AddressOf,
        OperationKind.Conversion,
        OperationKind.UnaryOperator,
        OperationKind.BinaryOperator,
        OperationKind.Increment,
        OperationKind.Decrement);
    }

    /// <summary>
    /// Returns the types an operation refers to. Entries may be null (e.g. an operator with no user-defined
    /// operator method); <see cref="VerifyType"/> treats null as allowed.
    /// </summary>
    static IEnumerable<ITypeSymbol?> GetReferencedTypes(IOperation operation)
    {
        switch (operation)
        {
            case IObjectCreationOperation objectCreation:
                yield return objectCreation.Type;
                break;
            case IInvocationOperation invocation:
                yield return invocation.TargetMethod.ContainingType;
                break;
            case IMemberReferenceOperation memberReference:
                yield return memberReference.Member.ContainingType;
                break;
            case IArrayCreationOperation arrayCreation:
                yield return arrayCreation.Type;
                break;
            case IAddressOfOperation addressOf:
                yield return addressOf.Type;
                break;
            case IConversionOperation conversion:
                // A user-defined conversion operator belongs to some type, and the conversion itself
                // (e.g. the cast in "(Type)value") produces a value of its target type.
                yield return conversion.OperatorMethod?.ContainingType;
                yield return conversion.Type;
                break;
            case IUnaryOperation unary:
                yield return unary.OperatorMethod?.ContainingType;
                break;
            case IBinaryOperation binary:
                yield return binary.OperatorMethod?.ContainingType;
                break;
            case IIncrementOrDecrementOperation incrementOrDecrement:
                yield return incrementOrDecrement.OperatorMethod?.ContainingType;
                break;
            // Any other kind is ignored rather than thrown on: an exception escaping an analyzer
            // callback is reported as an analyzer failure instead of a useful diagnostic.
        }
    }

    /// <summary>
    /// Checks <paramref name="type"/>, its type arguments (or array element / pointed-at type) and its
    /// containing types against the deny list, reporting at <paramref name="syntaxNode"/> on the first match.
    /// </summary>
    /// <returns><c>false</c> if a diagnostic was reported; otherwise <c>true</c>.</returns>
    bool VerifyType(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode)
    {
        do
        {
            // Validates nested types first and replaces `type` with the definition whose name is checked.
            if (!VerifyTypeArguments(reportDiagnostic, type, syntaxNode, out type))
            {
                return false;
            }

            if (type is null)
            {
                // Null for arrays and pointers (their element type was already checked) or a null input.
                return true;
            }

            var typeName = type.ToDisplayString(TypeNameFormat);
            if (_forbiddenTypeNames.Contains(typeName))
            {
                reportDiagnostic(Diagnostic.Create(Rule, syntaxNode.GetLocation(), typeName));
                return false;
            }

            // Walk outwards so a type nested inside a forbidden type is also reported.
            type = type.ContainingType;
        }
        while (!(type is null));

        return true;
    }

    /// <summary>
    /// Verifies the types nested in <paramref name="type"/>: generic type arguments (skipping type
    /// parameters and error types), an array's element type, or a pointer's pointed-at type.
    /// </summary>
    /// <param name="originalDefinition">
    /// The symbol whose own name should be checked next: <c>ConstructedFrom</c> for named types,
    /// <c>OriginalDefinition</c> for other types, and null for arrays and pointers.
    /// </param>
    /// <returns><c>false</c> if a nested type was forbidden (a diagnostic was reported).</returns>
    bool VerifyTypeArguments(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode, out ITypeSymbol? originalDefinition)
    {
        switch (type)
        {
            case INamedTypeSymbol namedTypeSymbol:
                originalDefinition = namedTypeSymbol.ConstructedFrom;
                foreach (var typeArgument in namedTypeSymbol.TypeArguments)
                {
                    if (typeArgument.TypeKind != TypeKind.TypeParameter &&
                        typeArgument.TypeKind != TypeKind.Error &&
                        !VerifyType(reportDiagnostic, typeArgument, syntaxNode))
                    {
                        return false;
                    }
                }
                break;
            case IArrayTypeSymbol arrayTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, arrayTypeSymbol.ElementType, syntaxNode);
            case IPointerTypeSymbol pointerTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, pointerTypeSymbol.PointedAtType, syntaxNode);
            default:
                originalDefinition = type?.OriginalDefinition;
                break;
        }

        return true;
    }
}
// UNIT TESTS
using System.Collections.Immutable;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Scripting;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Scripting;
using Xunit;
/// <summary>
/// Compiles C# script snippets and checks which diagnostics <see cref="ForbiddenTypeAnalyzer"/> reports for them.
/// </summary>
public class ForbiddenTypeAnalyzerTests
{
    [Theory]
    [InlineData(@"if (IsRequest) {
Reply(""test"");
}
else {
var env = Environment.CommandLine;
Reply(env);
}", "System.Environment")]
    [InlineData(@"if (IsRequest) {
Reply(""test"");
}
else {
var env = Environment.GetEnvironmentVariable(""test"");
Reply(env);
}", "System.Environment")]
    [InlineData(@"var args = Environment.CommandLine;", "System.Environment")]
    [InlineData(@"Console.WriteLine(""test"");", "System.Console")]
    [InlineData(@"var ptr = new IntPtr(32);", "System.IntPtr")]
    [InlineData(@"var type = (Type)SomeType;", "System.Type")]
    [InlineData(@"var type = (Type)SomeType; var name = type.Name;", "System.Type")]
    [InlineData(@"var type = Type.GetType(""name"");", "System.Type")]
    [InlineData(@"var pointers = new IntPtr[4];", "System.IntPtr")]
    public async Task ReturnsErrorsForForbiddenTypes(string code, string expectedForbiddenType)
    {
        var diagnosticResults = await GetDiagnosticsAsync(code, "System.Runtime.Extensions", "System.Console");

        // Compiler diagnostics are included too, so a snippet that fails to compile cannot pass by accident.
        var diagnostic = Assert.Single(diagnosticResults);
        Assert.Equal(
            $"Access to type {expectedForbiddenType} is forbidden.",
            diagnostic.GetMessage());
        Assert.Equal(ForbiddenTypeAnalyzer.DiagnosticId, diagnostic.Id);
    }

    [Fact]
    public async Task DoesNotReturnsErrorsForAllowedTypes()
    {
        const string code = @"if (IsRequest) {
Reply(""test"");
}
else {
var rnd = new Random();
Reply(rnd.Next(1).ToString());
}";

        var diagnosticResults = await GetDiagnosticsAsync(code, "System.Runtime.Extensions");

        Assert.Empty(diagnosticResults);
    }

    /// <summary>
    /// Compiles <paramref name="code"/> as a script over <see cref="IScriptGlobals"/>, importing <c>System</c>
    /// and referencing <paramref name="references"/>, and returns both compiler and analyzer diagnostics.
    /// </summary>
    static async Task<ImmutableArray<Diagnostic>> GetDiagnosticsAsync(string code, params string[] references)
    {
        var options = ScriptOptions.Default
            .WithImports("System")
            .WithEmitDebugInformation(true)
            .WithReferences(references)
            .WithAllowUnsafe(false);
        var script = CSharpScript.Create<dynamic>(code, globalsType: typeof(IScriptGlobals), options: options);
        var compilationWithAnalyzers = new CompilationWithAnalyzers(
            script.GetCompilation(),
            ImmutableArray.Create<DiagnosticAnalyzer>(new ForbiddenTypeAnalyzer()),
            new AnalyzerOptions(ImmutableArray<AdditionalText>.Empty),
            CancellationToken.None);
        return await compilationWithAnalyzers.GetAllDiagnosticsAsync();
    }

    /// <summary>Globals the test scripts can refer to by name.</summary>
    public interface IScriptGlobals
    {
        bool IsRequest { get; }
        void Reply(string reply);
        // Referenced by the "(Type)SomeType" cases; without it those snippets fail to compile (CS0103).
        object SomeType { get; }
    }
}
@haacked
Copy link
Author

haacked commented Oct 29, 2020

I have a minimal repro.

AnalyzerDemo.csproj

<Project Sdk="Microsoft.NET.Sdk">

  <!-- Console repro targeting .NET Core 3.1 with nullable reference types enabled.
       MSBuild property names are case-insensitive; "Nullable" is the documented spelling. -->
  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>netcoreapp3.1</TargetFramework>
    <Nullable>enable</Nullable>
  </PropertyGroup>

  <!-- Roslyn scripting API plus immutable collections used by the analyzer. -->
  <ItemGroup>
    <PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="3.3.1" />
    <PackageReference Include="System.Collections.Immutable" Version="1.7.1" />
  </ItemGroup>
</Project>

Program.cs

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Scripting;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Scripting;

/// <summary>
/// Repro driver: compiles a few C# script snippets and prints every diagnostic reported for each one,
/// including those produced by <see cref="ForbiddenTypeAnalyzer"/>.
/// </summary>
class Program
{
    static async Task Main(string[] args)
    {
        string[] snippets =
        {
            @"var env = Environment.CommandLine;",
            @"Console.WriteLine(""test"");",
        };

        foreach (var snippet in snippets)
        {
            await TestCode(snippet);
        }
    }

    /// <summary>Compiles <paramref name="code"/> and writes a summary line followed by one line per diagnostic.</summary>
    static async Task TestCode(string code)
    {
        ImmutableArray<Diagnostic> reported = await CompileCode(code);

        Console.WriteLine($"Compiling `{code}` resulted in {reported.Length} diagnostics.");
        for (var index = 0; index < reported.Length; index++)
        {
            Console.WriteLine(reported[index].ToString());
        }
    }

    /// <summary>
    /// Builds a script compilation of <paramref name="code"/> (importing System and System.IO) against a
    /// fixed set of assemblies from the runtime directory, and returns all compiler and analyzer diagnostics.
    /// </summary>
    static async Task<ImmutableArray<Diagnostic>> CompileCode(string code)
    {
        var runtimeDirectory = Path.GetDirectoryName(typeof(object).Assembly.Location)
                               ?? throw new InvalidOperationException("Could not find the assembly for object.");

        string[] assemblyFileNames =
        {
            "mscorlib.dll",
            "System.dll",
            "System.Core.dll",
            "System.Console.dll",
            "System.Runtime.dll",
            "System.Runtime.Extensions.dll",
        };

        // Lazy sequence: references are created when WithReferences enumerates it.
        var references = assemblyFileNames
            .Select(fileName => MetadataReference.CreateFromFile(Path.Combine(runtimeDirectory, fileName)));

        var scriptOptions = ScriptOptions.Default
            .WithImports("System", "System.IO")
            .WithEmitDebugInformation(true)
            .WithReferences(references)
            .WithAllowUnsafe(false);

        var compilation = CSharpScript
            .Create<dynamic>(code, globalsType: typeof(IScriptGlobals), options: scriptOptions)
            .GetCompilation();

        var analyzerHost = new CompilationWithAnalyzers(
            compilation,
            ImmutableArray.Create<DiagnosticAnalyzer>(new ForbiddenTypeAnalyzer()),
            new AnalyzerOptions(ImmutableArray<AdditionalText>.Empty),
            CancellationToken.None);

        return await analyzerHost.GetAllDiagnosticsAsync();
    }
}

/// <summary>
/// Globals type passed as <c>globalsType</c> when the repro creates its scripts in <c>CompileCode</c>.
/// It is empty: the repro snippets do not reference any globals.
/// </summary>
public interface IScriptGlobals
{
}

ForbiddenTypeAnalyzer.cs

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;

// CREDIT: https://github.com/dotnet/roslyn-analyzers/blob/master/src/Microsoft.CodeAnalysis.BannedApiAnalyzers/Core/SymbolIsBannedAnalyzer.cs
/// <summary>
/// Reports a warning wherever an analyzed operation refers to a type whose fully qualified name is on a
/// deny list (by default System.IO.StreamWriter, System.Console, System.Environment, System.IntPtr and System.Type).
/// </summary>
/// <remarks>
/// A type counts as referenced when it is created (object or array creation), when one of its members is
/// invoked or referenced (method, field, property, event), when its address is taken, when it is the
/// target of a conversion, or when it declares a user-defined operator that is used. Generic type
/// arguments, array element types, pointed-at types and containing types are checked as well.
/// NOTE(review): when analyzing CSharpScript compilations, some top-level statements (for example
/// <c>Console.WriteLine("test");</c>) were observed not to produce diagnostics; the cause is unconfirmed.
/// </remarks>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class ForbiddenTypeAnalyzer : DiagnosticAnalyzer
{
    /// <summary>Diagnostic ID reported by this analyzer.</summary>
    public const string DiagnosticId = nameof(ForbiddenTypeAnalyzer);

    static readonly LocalizableString Description = "Restricts the set of types that may be used.";
    const string Title = "Forbidden Type Analyzer";
    const string MessageFormat = "Access to type {0} is forbidden.";
    const string Category = "API Usage";

    // Names are rendered fully qualified with namespaces and containing types ("System.Int32", never the
    // "int" keyword) and without nullable-reference annotations, so a deny-list entry matches however the
    // type is spelled at the use site. Type parameters are kept, so generic entries are written as
    // e.g. "System.Collections.Generic.List<T>".
    static readonly SymbolDisplayFormat TypeNameFormat = new SymbolDisplayFormat(
        typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces,
        genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters);

    // Fully qualified names (see TypeNameFormat) that trigger a diagnostic; compared ordinally.
    readonly HashSet<string> _forbiddenTypeNames;

    static readonly IEnumerable<string> DefaultTypeAccessDenyList = new[]
    {
        "System.IO.StreamWriter",
        "System.Console",
        "System.Environment",
        "System.IntPtr",
        "System.Type"
    };

    static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(
        DiagnosticId,
        Title,
        MessageFormat,
        Category,
        DiagnosticSeverity.Warning,
        isEnabledByDefault: true,
        Description);

    /// <summary>Creates the analyzer with the default deny list.</summary>
    public ForbiddenTypeAnalyzer() : this(DefaultTypeAccessDenyList)
    {
    }

    ForbiddenTypeAnalyzer(IEnumerable<string> forbiddenTypeNames)
    {
        _forbiddenTypeNames = forbiddenTypeNames.ToHashSet(StringComparer.Ordinal);
    }

    public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(Rule);

    /// <summary>Enables concurrent, generated-code-inclusive analysis and defers registration to compilation start.</summary>
    public override void Initialize(AnalysisContext compilationContext)
    {
        compilationContext.EnableConcurrentExecution();
        compilationContext.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze |
                                                          GeneratedCodeAnalysisFlags.ReportDiagnostics);

        compilationContext.RegisterCompilationStartAction(OnCompilationStart);
    }

    /// <summary>Registers one operation callback covering every operation kind that can reference a type.</summary>
    void OnCompilationStart(CompilationStartAnalysisContext compilationContext)
    {
        compilationContext.RegisterOperationAction(context =>
            {
                foreach (var type in GetReferencedTypes(context.Operation))
                {
                    // Stop at the first forbidden type so each operation yields at most one diagnostic.
                    if (!VerifyType(context.ReportDiagnostic, type, context.Operation.Syntax))
                    {
                        return;
                    }
                }
            },
            OperationKind.ObjectCreation,
            OperationKind.Invocation,
            OperationKind.EventReference,
            OperationKind.FieldReference,
            OperationKind.MethodReference,
            OperationKind.PropertyReference,
            OperationKind.ArrayCreation,
            OperationKind.AddressOf,
            OperationKind.Conversion,
            OperationKind.UnaryOperator,
            OperationKind.BinaryOperator,
            OperationKind.Increment,
            OperationKind.Decrement);
    }

    /// <summary>
    /// Returns the types an operation refers to. Entries may be null (e.g. an operator with no user-defined
    /// operator method); <see cref="VerifyType"/> treats null as allowed.
    /// </summary>
    static IEnumerable<ITypeSymbol?> GetReferencedTypes(IOperation operation)
    {
        switch (operation)
        {
            case IObjectCreationOperation objectCreation:
                yield return objectCreation.Type;
                break;
            case IInvocationOperation invocation:
                yield return invocation.TargetMethod.ContainingType;
                break;
            case IMemberReferenceOperation memberReference:
                yield return memberReference.Member?.ContainingType;
                break;
            case IArrayCreationOperation arrayCreation:
                yield return arrayCreation.Type;
                break;
            case IAddressOfOperation addressOf:
                yield return addressOf.Type;
                break;
            case IConversionOperation conversion:
                // A user-defined conversion operator belongs to some type, and the conversion itself
                // (e.g. the cast in "(Type)value") produces a value of its target type.
                yield return conversion.OperatorMethod?.ContainingType;
                yield return conversion.Type;
                break;
            case IUnaryOperation unary:
                yield return unary.OperatorMethod?.ContainingType;
                break;
            case IBinaryOperation binary:
                yield return binary.OperatorMethod?.ContainingType;
                break;
            case IIncrementOrDecrementOperation incrementOrDecrement:
                yield return incrementOrDecrement.OperatorMethod?.ContainingType;
                break;
            // Any other kind is ignored rather than thrown on: an exception escaping an analyzer
            // callback is reported as an analyzer failure instead of a useful diagnostic.
        }
    }

    /// <summary>
    /// Checks <paramref name="type"/>, its type arguments (or array element / pointed-at type) and its
    /// containing types against the deny list, reporting at <paramref name="syntaxNode"/> on the first match.
    /// </summary>
    /// <returns><c>false</c> if a diagnostic was reported; otherwise <c>true</c>.</returns>
    bool VerifyType(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode)
    {
        do
        {
            // Validates nested types first and replaces `type` with the definition whose name is checked.
            if (!VerifyTypeArguments(reportDiagnostic, type, syntaxNode, out type))
            {
                return false;
            }

            if (type is null)
            {
                // Null for arrays and pointers (their element type was already checked) or a null input.
                return true;
            }

            var typeName = type.ToDisplayString(TypeNameFormat);
            if (_forbiddenTypeNames.Contains(typeName))
            {
                reportDiagnostic(Diagnostic.Create(Rule, syntaxNode.GetLocation(), typeName));
                return false;
            }

            // Walk outwards so a type nested inside a forbidden type is also reported.
            type = type.ContainingType;
        }
        while (!(type is null));

        return true;
    }

    /// <summary>
    /// Verifies the types nested in <paramref name="type"/>: generic type arguments (skipping type
    /// parameters and error types), an array's element type, or a pointer's pointed-at type.
    /// </summary>
    /// <param name="originalDefinition">
    /// The symbol whose own name should be checked next: <c>ConstructedFrom</c> for named types,
    /// <c>OriginalDefinition</c> for other types, and null for arrays and pointers.
    /// </param>
    /// <returns><c>false</c> if a nested type was forbidden (a diagnostic was reported).</returns>
    bool VerifyTypeArguments(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode, out ITypeSymbol? originalDefinition)
    {
        switch (type)
        {
            case INamedTypeSymbol namedTypeSymbol:
                originalDefinition = namedTypeSymbol.ConstructedFrom;
                foreach (var typeArgument in namedTypeSymbol.TypeArguments)
                {
                    if (typeArgument.TypeKind != TypeKind.TypeParameter &&
                        typeArgument.TypeKind != TypeKind.Error &&
                        !VerifyType(reportDiagnostic, typeArgument, syntaxNode))
                    {
                        return false;
                    }
                }
                break;

            case IArrayTypeSymbol arrayTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, arrayTypeSymbol.ElementType, syntaxNode);

            case IPointerTypeSymbol pointerTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, pointerTypeSymbol.PointedAtType, syntaxNode);

            default:
                originalDefinition = type?.OriginalDefinition;
                break;
        }

        return true;
    }
}

The output is

Compiling `var env = Environment.CommandLine;` resulted in 1 diagnostics.
(1,11): warning ForbiddenTypeAnalyzer: Access to type System.Environment is forbidden.
Compiling `Console.WriteLine("test");` resulted in 0 diagnostics.

@haacked
Copy link
Author

haacked commented Oct 29, 2020

Just to be sure the code is compiling correctly, I defined a concrete ScriptGlobals class.

// Empty concrete implementation of IScriptGlobals, used only to run the repro script via script.RunAsync(...).
public class ScriptGlobals : IScriptGlobals
{
}

And then tried the following right after I create the script in Program.cs.

await script.RunAsync(new ScriptGlobals());

And I see the string "test" in the console, so the script runs fine. But when I step through in the debugger, the callback registered with RegisterOperationAction is never invoked when calling await TestCode(@"Console.WriteLine(""test"");");.

@haacked
Copy link
Author

haacked commented Oct 29, 2020

Here are a couple more cases that behave strangely.

var write = new StreamWriter("some-path");  // Is flagged by the analyzer
new StreamWriter("some-path");              // Is NOT flagged by the analyzer.

@jmarolf
Copy link

jmarolf commented Oct 29, 2020

new StreamWriter("some-path"); isn't valid code though right?

@haacked
Copy link
Author

haacked commented Oct 29, 2020

@jmarolf

How about

new StreamWriter("some-path").Flush();

That also doesn't trigger the analyzer.

Also, if it's not valid code, I would expect Roslyn to report it, but it compiles without any diagnostics. Could something I'm doing be preventing Roslyn from reporting errors?

@jmarolf
Copy link

jmarolf commented Oct 29, 2020

I forgot this was in the scripting dialect and not a top level function. @jaredpar what are the language rules for these cases?

@jmarolf
Copy link

jmarolf commented Oct 29, 2020

Well regardless, I'll just need to take some time this evening to debug through this instead of guessing.

@haacked
Copy link
Author

haacked commented Oct 29, 2020

To clarify, I've updated my repro a tiny bit to make sure it imports System.IO and references System.Runtime.Extensions. Then I added the following two lines to Program.cs

await TestCode(@"var x = new StreamWriter(""some-path"");");
await TestCode(@"new StreamWriter(""some-path"").Flush();");

And the result is

Compiling `var x = new StreamWriter("some-path");` resulted in 0 diagnostics.
Compiling `new StreamWriter("some-path").Flush();` resulted in 0 diagnostics.

This contradicts what I said earlier (I messed up my testing). I'm not sure why the debugger never breaks into my RegisterOperationAction callback.

@haacked
Copy link
Author

haacked commented Oct 29, 2020

I forgot this was in the scripting dialect and not a top level function.

Yeah, I can't use C# 9 just yet. 😦

I changed the code to use the regular compilation.

var syntaxTree = CSharpSyntaxTree.ParseText(code);
        var compilation = CSharpCompilation.Create(
            "assemblyName",
            new[] { syntaxTree },
            references,
            new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

And then ran this...

await TestCode(@"
using System.IO;

public class Test
{
    public void DoStuff()
    {
        new StreamWriter(""test"").Flush();
    }
}
");

And got the result I expected: (8,9): warning ForbiddenTypeAnalyzer: Access to type System.IO.StreamWriter is forbidden.

So it does appear to be a difference between the C# scripting dialect and C#.

@haacked
Copy link
Author

haacked commented Oct 30, 2020

Just for completeness' sake, I tried this out with CSharpSymbolIsBannedAnalyzer. Here's the banned-symbols file I used.

T:System.Console;Don't use System.Console
T:System.Environment;Don't use System.Environment
T:System.Type;Don't use System.Type
T:System.Reflection.MemberInfo;Don't use Reflection.
T:System.IO.StreamWriter;Don't use System.IO.StreamWriter.

Here are the test cases I used (the same program as above, but with the CSharpSymbolIsBannedAnalyzer analyzer swapped in).

await TestCode(@"new StreamWriter(""some-path"").Flush();");
await TestCode("var env = Environment.CommandLine;");
await TestCode("var type = (Type)SomeType; var name = type.Name;");
await TestCode(@"Console.WriteLine(""test"");");

And here are the results.

Compiling `new StreamWriter("some-path").Flush();` resulted in 0 diagnostics.
Compiling `var env = Environment.CommandLine;` resulted in 1 diagnostics.
(1,11): warning RS0030: The symbol 'Environment' is banned in this project: Don't use System.Environment
Compiling `var type = (Type)SomeType; var name = type.Name;` resulted in 1 diagnostics.
(1,39): warning RS0030: The symbol 'MemberInfo' is banned in this project: Don't use Reflection.
Compiling `Console.WriteLine("test");` resulted in 0 diagnostics.

It seems like there are cases where the CSharpSymbolIsBannedAnalyzer doesn't work with CSharpScript. I'm curious to understand why. Is it a bug? By design?

@haacked
Copy link
Author

haacked commented Nov 1, 2020

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment