Skip to content

Instantly share code, notes, and snippets.

@haacked
Last active May 14, 2023 18:38
Show Gist options
  • Save haacked/00de560d00692b7f4859336c747af10e to your computer and use it in GitHub Desktop.
Save haacked/00de560d00692b7f4859336c747af10e to your computer and use it in GitHub Desktop.
Roslyn Analyzer to warn about access to forbidden types
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;
// CREDIT: https://github.com/dotnet/roslyn-analyzers/blob/master/src/Microsoft.CodeAnalysis.BannedApiAnalyzers/Core/SymbolIsBannedAnalyzer.cs
/// <summary>
/// Reports a warning whenever code creates an instance of, invokes a method on, references a
/// member of, or uses a user-defined operator of a type on the deny list. By default the list is
/// <c>System.Console</c>, <c>System.Environment</c>, <c>System.IntPtr</c>, and <c>System.Type</c>.
/// </summary>
/// <remarks>
/// Type arguments, array element types, pointed-at types, and containing types are checked too.
/// For example, <c>new List&lt;Console&gt;()</c>, <c>new IntPtr[4]</c>, and members of a type
/// nested inside a forbidden type are all reported.
/// </remarks>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class ForbiddenTypeAnalyzer : DiagnosticAnalyzer
{
    public const string DiagnosticId = nameof(ForbiddenTypeAnalyzer);

    static readonly LocalizableString Description = "Restricts the set of types that may be used.";
    const string Title = "Forbidden Type Analyzer";
    const string MessageFormat = "Access to type {0} is forbidden.";
    const string Category = "API Usage";

    // Format used to turn a type symbol into the name looked up in the deny list (and shown in
    // the message). ISymbol.ToString() uses the error-message format, which substitutes C#
    // keywords for special types ("string", "object", ...). That means entries written as CLR
    // names such as "System.String" could never match. This format omits keyword substitution
    // and the "global::" prefix but keeps namespaces, containing types, and type parameters
    // (e.g. "System.Collections.Generic.List<T>").
    static readonly SymbolDisplayFormat TypeNameFormat = new SymbolDisplayFormat(
        globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Omitted,
        typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces,
        genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters);

    readonly HashSet<string> _forbiddenTypeNames;

    static readonly IEnumerable<string> DefaultTypeAccessDenyList = new[]
    {
        "System.Console",
        "System.Environment",
        "System.IntPtr",
        "System.Type"
    };

    static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(
        DiagnosticId,
        Title,
        MessageFormat,
        Category,
        DiagnosticSeverity.Warning,
        isEnabledByDefault: true,
        Description);

    /// <summary>Creates the analyzer with the default deny list.</summary>
    public ForbiddenTypeAnalyzer() : this(DefaultTypeAccessDenyList)
    {
    }

    /// <summary>Creates the analyzer with a custom deny list of fully qualified type names.</summary>
    /// <param name="forbiddenTypeNames">
    /// Names compared ordinally against <see cref="TypeNameFormat"/> output,
    /// e.g. <c>"System.Console"</c>.
    /// </param>
    ForbiddenTypeAnalyzer(IEnumerable<string> forbiddenTypeNames)
    {
        // Use the HashSet constructor rather than Enumerable.ToHashSet: ToHashSet does not exist
        // on netstandard2.0, the framework analyzer assemblies normally have to target.
        _forbiddenTypeNames = new HashSet<string>(forbiddenTypeNames, StringComparer.Ordinal);
    }

    public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(Rule);

    /// <summary>
    /// Registers one operation callback for every operation kind that can reference a type.
    /// Generated code is analyzed and reported on too.
    /// </summary>
    public override void Initialize(AnalysisContext compilationContext)
    {
        compilationContext.EnableConcurrentExecution();
        compilationContext.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics);
        compilationContext.RegisterOperationAction(
            context => VerifyType(context.ReportDiagnostic, GetReferencedType(context.Operation), context.Operation.Syntax),
            OperationKind.ObjectCreation,
            OperationKind.Invocation,
            OperationKind.EventReference,
            OperationKind.FieldReference,
            OperationKind.MethodReference,
            OperationKind.PropertyReference,
            OperationKind.ArrayCreation,
            OperationKind.AddressOf,
            OperationKind.Conversion,
            OperationKind.UnaryOperator,
            OperationKind.BinaryOperator,
            OperationKind.Increment,
            OperationKind.Decrement);
    }

    /// <summary>
    /// Returns the type an operation uses, or <c>null</c> when it uses none worth checking.
    /// </summary>
    /// <remarks>
    /// For member references and invocations this is the containing type of the member. For
    /// conversions and operators it is the type declaring the user-defined operator method; it is
    /// <c>null</c> for built-in conversions and operators.
    /// </remarks>
    static ITypeSymbol? GetReferencedType(IOperation operation) => operation switch
    {
        IObjectCreationOperation objectCreation => objectCreation.Type,
        IInvocationOperation invocationOperation => invocationOperation.TargetMethod.ContainingType,
        IMemberReferenceOperation memberReference => memberReference.Member.ContainingType,
        IArrayCreationOperation arrayCreation => arrayCreation.Type,
        IAddressOfOperation addressOf => addressOf.Type,
        IConversionOperation conversion => conversion.OperatorMethod?.ContainingType,
        IUnaryOperation unary => unary.OperatorMethod?.ContainingType,
        IBinaryOperation binary => binary.OperatorMethod?.ContainingType,
        IIncrementOrDecrementOperation incrementOrDecrement => incrementOrDecrement.OperatorMethod?.ContainingType,
        // Every registered kind is matched above. Anything else is skipped rather than thrown:
        // an exception escaping an analyzer callback is reported as an AD0001 analyzer failure.
        _ => null
    };

    /// <summary>
    /// Checks <paramref name="type"/>, its type arguments, and its containing types against the
    /// deny list. Reports at most one diagnostic, at <paramref name="syntaxNode"/>.
    /// </summary>
    /// <returns><c>false</c> if a forbidden type was found and reported; otherwise <c>true</c>.</returns>
    bool VerifyType(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode)
    {
        do
        {
            if (!VerifyTypeArguments(reportDiagnostic, type, syntaxNode, out type))
            {
                return false;
            }

            if (type is null)
            {
                // Null when there was nothing to check, or for arrays and pointers, whose
                // element / pointed-at type was already verified by VerifyTypeArguments.
                return true;
            }

            var typeName = type.ToDisplayString(TypeNameFormat);
            if (_forbiddenTypeNames.Contains(typeName))
            {
                reportDiagnostic(Diagnostic.Create(Rule, syntaxNode.GetLocation(), typeName));
                return false;
            }

            // A nested type is also rejected when any of its containing types is forbidden.
            type = type.ContainingType;
        }
        while (!(type is null));

        return true;
    }

    /// <summary>
    /// Verifies the type arguments (or array element / pointed-at type) of <paramref name="type"/>.
    /// </summary>
    /// <param name="originalDefinition">
    /// The unconstructed definition to check by name next (e.g. <c>List&lt;T&gt;</c> for
    /// <c>List&lt;int&gt;</c>), or <c>null</c> for arrays and pointers.
    /// </param>
    /// <returns><c>false</c> if a forbidden type was found and reported; otherwise <c>true</c>.</returns>
    bool VerifyTypeArguments(Action<Diagnostic> reportDiagnostic, ITypeSymbol? type, SyntaxNode syntaxNode, out ITypeSymbol? originalDefinition)
    {
        switch (type)
        {
            case INamedTypeSymbol namedTypeSymbol:
                originalDefinition = namedTypeSymbol.ConstructedFrom;
                foreach (var typeArgument in namedTypeSymbol.TypeArguments)
                {
                    // Unsubstituted type parameters and unresolved (error) types have no
                    // meaningful name to check.
                    if (typeArgument.TypeKind != TypeKind.TypeParameter &&
                        typeArgument.TypeKind != TypeKind.Error &&
                        !VerifyType(reportDiagnostic, typeArgument, syntaxNode))
                    {
                        return false;
                    }
                }
                break;
            case IArrayTypeSymbol arrayTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, arrayTypeSymbol.ElementType, syntaxNode);
            case IPointerTypeSymbol pointerTypeSymbol:
                originalDefinition = null;
                return VerifyType(reportDiagnostic, pointerTypeSymbol.PointedAtType, syntaxNode);
            default:
                originalDefinition = type?.OriginalDefinition;
                break;
        }

        return true;
    }
}
// UNIT TESTS
using System.Collections.Immutable;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Scripting;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Scripting;
using Xunit;
/// <summary>
/// Compiles short C# script snippets and checks the diagnostics produced when
/// <see cref="ForbiddenTypeAnalyzer"/> runs over them.
/// </summary>
public class ForbiddenTypeAnalyzerTests
{
    [Theory]
    [InlineData(@"if (IsRequest) {
Reply(""test"");
}
else {
var env = Environment.CommandLine;
Reply(env);
}", "System.Environment")]
    [InlineData(@"if (IsRequest) {
Reply(""test"");
}
else {
var env = Environment.GetEnvironmentVariable(""test"");
Reply(env);
}", "System.Environment")]
    [InlineData(@"var args = Environment.CommandLine;", "System.Environment")]
    [InlineData(@"Console.WriteLine(""test"");", "System.Console")]
    [InlineData(@"var ptr = new IntPtr(32);", "System.IntPtr")]
    [InlineData(@"var type = (Type)SomeType;", "System.Type")]
    [InlineData(@"var type = (Type)SomeType; var name = type.Name;", "System.Type")]
    [InlineData(@"var type = Type.GetType(""name"");", "System.Type")]
    [InlineData(@"var pointers = new IntPtr[4];", "System.IntPtr")]
    public async Task ReturnsErrorsForForbiddenTypes(string code, string expectedForbiddenType)
    {
        var diagnosticResults = await GetDiagnosticsAsync(code, "System.Runtime.Extensions", "System.Console");

        var diagnostic = Assert.Single(diagnosticResults);
        Assert.Equal(
            $"Access to type {expectedForbiddenType} is forbidden.",
            diagnostic.GetMessage());
        Assert.Equal(ForbiddenTypeAnalyzer.DiagnosticId, diagnostic.Id);
    }

    [Fact]
    public async Task DoesNotReturnsErrorsForAllowedTypes()
    {
        const string code = @"if (IsRequest) {
Reply(""test"");
}
else {
var rnd = new Random();
Reply(rnd.Next(1).ToString());
}";

        var diagnosticResults = await GetDiagnosticsAsync(code, "System.Runtime.Extensions");

        Assert.Empty(diagnosticResults);
    }

    /// <summary>
    /// Compiles <paramref name="code"/> as a C# script, runs <see cref="ForbiddenTypeAnalyzer"/>
    /// over it, and returns every diagnostic.
    /// </summary>
    /// <remarks>
    /// The script imports <c>System</c>, uses <see cref="IScriptGlobals"/> as its globals type, and
    /// does not allow unsafe code. Because <c>GetAllDiagnosticsAsync</c> is used, compiler
    /// diagnostics are returned alongside analyzer ones.
    /// </remarks>
    /// <param name="code">Script source to compile.</param>
    /// <param name="references">Assembly names added as metadata references.</param>
    static async Task<ImmutableArray<Diagnostic>> GetDiagnosticsAsync(string code, params string[] references)
    {
        var options = ScriptOptions.Default
            .WithImports("System")
            .WithEmitDebugInformation(true)
            .WithReferences(references)
            .WithAllowUnsafe(false);
        var script = CSharpScript.Create<dynamic>(code, globalsType: typeof(IScriptGlobals), options: options);
        var compilation = script.GetCompilation();
        var analyzers = ImmutableArray.Create<DiagnosticAnalyzer>(
            new ForbiddenTypeAnalyzer());
        var compilationWithAnalyzers = new CompilationWithAnalyzers(
            compilation,
            analyzers,
            new AnalyzerOptions(ImmutableArray<AdditionalText>.Empty),
            CancellationToken.None);
        return await compilationWithAnalyzers.GetAllDiagnosticsAsync();
    }

    /// <summary>Globals exposed to the test scripts (<c>IsRequest</c>, <c>Reply</c>).</summary>
    public interface IScriptGlobals
    {
        bool IsRequest { get; }
        void Reply(string reply);
    }
}
@haacked
Copy link
Author

haacked commented Oct 29, 2020

To clarify, I've updated my repro slightly so that it imports System.IO and references System.Runtime.Extensions. Then I added the following two lines to Program.cs:

await TestCode(@"var x = new StreamWriter(""some-path"");");
await TestCode(@"new StreamWriter(""some-path"").Flush();");

And the result is

Compiling `var x = new StreamWriter("some-path");` resulted in 0 diagnostics.
Compiling `new StreamWriter("some-path").Flush();` resulted in 0 diagnostics.

This contradicts what I said earlier (I messed up my testing). I'm not sure why my code never breaks into the RegisterOperationAction callback.

@haacked
Copy link
Author

haacked commented Oct 29, 2020

I forgot this was in the scripting dialect and not a top-level function.

Yeah, I can't use C# 9 just yet. 😦

I changed the code to use a regular compilation instead:

var syntaxTree = CSharpSyntaxTree.ParseText(code);
        var compilation = CSharpCompilation.Create(
            "assemblyName",
            new[] { syntaxTree },
            references,
            new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));

And then ran this...

await TestCode(@"
using System.IO;

public class Test
{
    public void DoStuff()
    {
        new StreamWriter(""test"").Flush();
    }
}
");

And got the result I expected: `(8,9): warning ForbiddenTypeAnalyzer: Access to type System.IO.StreamWriter is forbidden.`

So it does appear to be a difference between the C# scripting dialect and regular C#.

@haacked
Copy link
Author

haacked commented Oct 30, 2020

Just for completeness' sake, I tried this out with CSharpSymbolIsBannedAnalyzer. Here's the symbols file I used:

T:System.Console;Don't use System.Console
T:System.Environment;Don't use System.Environment
T:System.Type;Don't use System.Type
T:System.Reflection.MemberInfo;Don't use Reflection.
T:System.IO.StreamWriter;Don't use System.IO.StreamWriter.

Here are the test cases I used (the same program as above, but with CSharpSymbolIsBannedAnalyzer swapped in):

await TestCode(@"new StreamWriter(""some-path"").Flush();");
await TestCode("var env = Environment.CommandLine;");
await TestCode("var type = (Type)SomeType; var name = type.Name;");
await TestCode(@"Console.WriteLine(""test"");");

And here are the results.

Compiling `new StreamWriter("some-path").Flush();` resulted in 0 diagnostics.
Compiling `var env = Environment.CommandLine;` resulted in 1 diagnostics.
(1,11): warning RS0030: The symbol 'Environment' is banned in this project: Don't use System.Environment
Compiling `var type = (Type)SomeType; var name = type.Name;` resulted in 1 diagnostics.
(1,39): warning RS0030: The symbol 'MemberInfo' is banned in this project: Don't use Reflection.
Compiling `Console.WriteLine("test");` resulted in 0 diagnostics.

It seems like there are cases where CSharpSymbolIsBannedAnalyzer doesn't work with CSharpScript. I'm curious to understand why. Is it a bug, or is it by design?

@haacked
Copy link
Author

haacked commented Nov 1, 2020

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment